CNN で輪郭の検出

画像内の物体の輪郭検出を CNN(畳み込みニューラルネット)で試してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190114/

概要

今回は、画像をピクセル単位で輪郭か否かに分類する事(輪郭 = 1, 輪郭以外 = 0)で輪郭を検出できないか試しました。

そこで、教師データとして以下のような衣服単体の画像(jpg)と衣服の輪郭部分だけを白く塗りつぶした画像(png)を用意しました。

f:id:fits:20190114225837j:plain

教師データを大量に用意するのは困難だったため、240x288 の画像 160 ファイルで学習を行っています。

学習

学習の処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、入力画像(jpg)を読み込みます。(教師データの画像は img ディレクトリへ配置しています)

import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

files = glob.glob('img/*.jpg')

imgs = np.array([img_to_array(load_img(f)) for f in files])

imgs.shape

入力データの形状は以下の通りです。

imgs.shape 結果
(160, 288, 240, 3)

(2) ラベルデータの準備

輪郭画像(png)を読み込み、128 を境にして二値化(輪郭 = 1、輪郭以外 = 0)します。

import os

th = 128

labels = np.array([img_to_array(load_img(f"{os.path.splitext(f)[0]}.png", color_mode = 'grayscale')) for f in files])

labels[labels < th] = 0
labels[labels >= th] = 1

labels.shape

ラベルデータの形状は以下の通りです。

labels.shape 結果
(160, 288, 240, 1)

(3) CNN モデル

どのようなネットワーク構成が適しているのか分からなかったので、セマンティックセグメンテーション等で用いられている Encoder-Decoder の構成を参考にしてみました。

30x36 まで段階的に縮小して(Encoder)、元の大きさ 240x288 まで段階的に拡大する(Decoder)ようにしています。

最終層の活性化関数に sigmoid を使って 0 ~ 1 の値となるようにしています。

損失関数は binary_crossentropy を使うと進捗が遅そうに見えたので ※、代わりに mean_squared_error を使っています。

 ※ 今回の場合、輪郭(= 1)に該当するピクセルの方が少なくなるため
    binary_crossentropy を使用する場合は fit の class_weight 引数で
    調整する必要があったと思われます

また、参考のため mean_absolute_error(mae) の値も出力するように metrics で指定しています。

from keras.models import Model
from keras.layers import Input, Dropout
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

# Encoder

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

# Decoder

x = Conv2DTranspose(64, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(32, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(16, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

output = Conv2D(1, 1, activation = 'sigmoid')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])

model.summary()
model.summary() 結果
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 288, 240, 3)       0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 288, 240, 3)       12        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 288, 240, 16)      448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 144, 120, 16)      0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 144, 120, 16)      64        
_________________________________________________________________
dropout_1 (Dropout)          (None, 144, 120, 16)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 144, 120, 32)      4640      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 72, 60, 32)        0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 72, 60, 32)        128       
_________________________________________________________________
dropout_2 (Dropout)          (None, 72, 60, 32)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 72, 60, 64)        18496     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 36, 30, 64)        0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 36, 30, 64)        256       
_________________________________________________________________
dropout_3 (Dropout)          (None, 36, 30, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 36, 30, 128)       73856     
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 36, 30, 128)       147584    
_________________________________________________________________
batch_normalization_5 (Batch (None, 36, 30, 128)       512       
_________________________________________________________________
dropout_4 (Dropout)          (None, 36, 30, 128)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 72, 60, 64)        73792     
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
batch_normalization_6 (Batch (None, 72, 60, 64)        256       
_________________________________________________________________
dropout_5 (Dropout)          (None, 72, 60, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 144, 120, 32)      18464     
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
batch_normalization_7 (Batch (None, 144, 120, 32)      128       
_________________________________________________________________
dropout_6 (Dropout)          (None, 144, 120, 32)      0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 288, 240, 16)      4624      
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 288, 240, 16)      2320      
_________________________________________________________________
batch_normalization_8 (Batch (None, 288, 240, 16)      64        
_________________________________________________________________
dropout_7 (Dropout)          (None, 288, 240, 16)      0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 288, 240, 1)       17        
=================================================================
Total params: 576,541
Trainable params: 575,831
Non-trainable params: 710
_________________________________________________________________

(4) 学習

教師データが少ないため、全て学習で使う事にします。

実行例(441 ~ 480 エポック)
hist = model.fit(imgs, labels, initial_epoch = 440, epochs = 480, batch_size = 10)

Keras では fit を繰り返し呼び出すと学習を(続きから)再開できるので、40 エポックを何回か繰り返しました。(バッチサイズは 20 で始めて途中で 10 へ変えたりしています)

その場合、正しいエポックを出力するには initial_epochepochs の値を調整する必要があります ※。

 ※ initial_epoch を指定しなくても
    fit を繰り返し実行するだけで学習は継続されますが、
    その場合は出力されるエポックの値がクリアされます(1 からのカウントとなる)
結果例
Epoch 441/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
Epoch 442/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0125
Epoch 443/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
・・・
Epoch 478/480
160/160 [=====・・・ - loss: 0.0045 - mean_absolute_error: 0.0116
Epoch 479/480
160/160 [=====・・・ - loss: 0.0046 - mean_absolute_error: 0.0117
Epoch 480/480
160/160 [=====・・・ - loss: 0.0044 - mean_absolute_error: 0.0115

(5) 確認

fit の戻り値から mean_squared_error(loss)と mean_absolute_error の値の遷移をグラフ化してみます。

%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (8, 4)

plt.subplot(1, 2, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 2, 2)
plt.plot(hist.history['mean_absolute_error'])
結果例(441 ~ 480 エポック)

f:id:fits:20190114231037p:plain

(6) 検証

(a) 教師データ

教師データの入力画像と輪郭画像(ラベルデータ)、model.predict の結果を並べて表示してみます。

def predict(index, s = 6.0):
    plt.rcParams['figure.figsize'] = (s, s)

    sh = imgs.shape[1:-1]

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([imgs[index]]))[0]
    pred *= 255

    plt.subplot(1, 3, 1)
    # 入力画像の表示
    plt.imshow(imgs[index].astype(int))

    plt.subplot(1, 3, 2)
    # 輪郭画像(ラベルデータ)の表示
    plt.imshow(labels[index].reshape(sh), cmap = 'gray')

    plt.subplot(1, 3, 3)
    # predict の結果表示
    plt.imshow(pred.reshape(sh).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, 輪郭画像, model.predict 結果

f:id:fits:20190114231107j:plain

概ね教師データに近い結果が出るようになっています。

(b) 教師データ以外

教師データとして使っていない画像に対して model.predict を実施し、輪郭の検出を行ってみます。

def predict_eval(file, s = 4.0):
    plt.rcParams['figure.figsize'] = (s, s)

    img = img_to_array(load_img(file))

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    plt.subplot(1, 2, 1)
    # 入力画像の表示
    plt.imshow(img.astype(int))

    plt.subplot(1, 2, 2)
    # predict の結果表示
    plt.imshow(pred.reshape(pred.shape[:-1]).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, model.predict 結果

f:id:fits:20190114231154j:plain

所々で途切れたりしていますが、ある程度の輪郭は検出できているように見えます。

(7) 保存

学習したモデルを保存します。

model.save('model/c1_480.h5')

輪郭検出

学習済みモデルを使って輪郭検出を行う処理をスクリプト化してみました。

predict_contours.py
import sys
import os
import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
from keras.models import load_model
import cv2

model_file = sys.argv[1]
img_files = sys.argv[2]
dest_dir = sys.argv[3]

model = load_model(model_file)

for f in glob.glob(img_files):
    img = img_to_array(load_img(f))

    # 輪郭の検出
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    file, ext = os.path.splitext(os.path.basename(f))

    # 画像の保存
    cv2.imwrite(f"{dest_dir}/{file}_predict.png", pred)

    print(f"done: {f}")
実行例
python predict_contours.py model/c1_480.h5 img_eval2/*.jpg result

480 エポックの学習モデルを使って、教師データに無いタイプの背景を使った画像(影の影響もある)に試してみました。

輪郭検出結果例
入力画像 処理結果
f:id:fits:20181229154654j:plain f:id:fits:20190114231530p:plain
f:id:fits:20181229155049j:plain f:id:fits:20190114231547p:plain

こちらは難しかったようです。

なお、2つ目の画像は 120 エポックの学習モデルの方が良好な結果(輪郭がより多く検出されていた)でした。

MongoDB で条件に合致する子要素を抽出

MongoDB で指定の条件に合致する子要素のみを抽出する方法を調査してみました。

  • MongoDB 4.0.4

はじめに、下記 3つのドキュメントが sample コレクションへ登録されているとします。

ドキュメント内容
{ "_id" : 1, "items" : [
    { "color" : "black", "size" : "S" }, 
    { "color" : "white", "size" : "S" }
] }

{ "_id" : 2, "items" : [
    { "color" : "red",   "size" : "L" }, 
    { "color" : "blue",  "size" : "S" }
] }

{ "_id" : 3, "items" : [
    { "color" : "white", "size" : "L" }, 
    { "color" : "red",   "size" : "L" }, 
    { "color" : "white", "size" : "S" }
] }

ここで、items.colorwhite のものだけを抽出し、以下の結果を得る事を目指してみます。(items の中身が white のものだけを含むようにする)

目標とする検索結果
{ "_id" : 1, "items" : [
    { "color" : "white", "size" : "S" }
] }

{ "_id" : 3, "items" : [
    { "color" : "white", "size" : "L" }, 
    { "color" : "white", "size" : "S" }
] }

(a) items.color で条件指定

まずは {"items.color": "white"} の条件で find した結果です。

white を持つドキュメントだけを抽出できましたが、ドキュメントの内容はそのままなので black 等の余計なものも含んでしまいます。

> db.sample.find({"items.color": "white"})

{ "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

(b) $elemMatch 使用

次に $elemMatch を使ってみます。

$elemMatch を find の query(第一引数)で使うか、projection(第二引数)で使うかで結果が変わります。

query で使う場合は先程の (a) と同じ結果になります。

query で使用
> db.sample.find({"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

query の条件は指定せずに projection で $elemMatch を使った場合の結果は以下です。

全ドキュメントを対象に items の中身がフィルタリングされていますが、条件に合致する全ての子要素が抽出されるわけでは無く、(条件に合致する)先頭の要素しか含まれていません。

projection で使用1
> db.sample.find({}, {"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 2 }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }

query 条件を指定する事で不要なドキュメントを除く事はできますが、条件に合致する全ての子要素が抽出されない事に変わりはありません。

projection で使用2
> db.sample.find({"items.color": "white"}, {"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }

このように、今回のようなドキュメントに対して $elemMatch を使うと、条件に合致する最初の子要素だけが抽出されるようです。(何らかの回避策があるのかもしれませんが)

(c) aggregate 使用

最後に aggregate を使ってみます。

$unwind を使うと配列の個々の要素を処理できるので、$match で white のみに限定した後、$group でグルーピングすれば良さそうです。

> db.sample.aggregate([
  {$unwind: "$items"}, 
  {$match: {"items.color": "white"}}, 
  {$group: {_id: "$_id", "items": {$push: "$items"}}}
])

{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }

これで目指した結果は一応得られました。

なお、対象を最初に絞り込むようにしてソートを付けると以下のようになります。

> db.sample.aggregate([
  {$match: {"items.color": "white"}},
  {$unwind: "$items"},
  {$match: {"items.color": "white"}},
  {$group: {_id: "$_id", "items": {$push: "$items"}}},
  {$sort:  {_id: 1}}
])

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

Scala のケースクラスに制約を持たせる

Scala のケースクラスで値に制約を持たせたい場合にどうするか。

例えば、以下のケースクラスで amount の値を 0 以上となるように制限し、0 未満ならインスタンス化を失敗させる事を考えてみます。

case class Quantity(amount: Int)

使用した環境は以下

ソースは http://github.com/fits/try_samples/tree/master/blog/20181009/

ケースクラスの値を制限

まず、最も単純なのは以下のような実装だと思います。

case class Quantity(amount: Int) {
  if (amount < 0)
    throw new IllegalArgumentException(s"amount($amount) < 0")
}

これだと、例外が throw されてしまい関数プログラミングで扱い難いので Try[Quantity]Option[Quantity] 等を返すようにしたいところです。

そこで、以下のようにケースクラスを abstract 化して、コンパニオンオブジェクトへ生成関数を定義する方法を使ってみました。

sample.scala
import scala.util.{Try, Success, Failure}

sealed abstract case class Quantity private (amount: Int)

object Quantity {
  def apply(amount: Int): Try[Quantity] =
    if (amount >= 0)
      Success(new Quantity(amount){})
    else
      Failure(new IllegalArgumentException(s"amount($amount) < 0"))
}

println(Quantity(1))
println(Quantity(0))
println(Quantity(-1))

// この方法では Quantity へ copy がデフォルト定義されないため
// copy は使えません(error: value copy is not a member of this.Quantity)
//
// println(Quantity(1).map(_.copy(-1)))

実行結果は以下の通りです。

実行結果
> scala sample.scala

Success(Quantity(1))
Success(Quantity(0))
Failure(java.lang.IllegalArgumentException: amount(-1) < 0)

上記 sample.scala では、以下を直接呼び出せないようにしてケースクラスの勝手なインスタンス化を防止しています。

  • (a) コンストラクタ(new)
  • (b) コンパニオンオブジェクトへデフォルト定義される apply
  • (c) ケースクラスへデフォルト定義される copy

そのために、下記 2点を実施しています。

  • (1) コンストラクタの private 化 : (a) の防止
  • (2) ケースクラスの abstract 化 : (b) (c) の防止

(1) コンストラクタの private 化

以下のように private を付ける事でコンストラクタを private 化できます。

コンストラクタの private 化
case class Quantity private (amount: Int)

これで (a) new Quantity(・・・) の実行を防止できますが、以下のように (b) の apply や (c) の copy を実行できてしまいます。

検証例
scala> case class Quantity private (amount: Int)
defined class Quantity

scala> new Quantity(1)
<console>:14: error: constructor Quantity in class Quantity cannot be accessed in object $iw
       new Quantity(1)

scala> Quantity(1)
res1: Quantity = Quantity(1)

scala> Quantity.apply(2)
res2: Quantity = Quantity(2)

scala> Quantity(3).copy(30)
res3: Quantity = Quantity(30)

(2) ケースクラスの abstract 化

ケースクラスを abstract 化すると、通常ならデフォルト定義されるコンパニオンオブジェクトの apply やケースクラスの copy を防止できるようです。

そのため、(1) と組み合わせることで (a) ~ (c) を防止できます。

ケースクラスの abstract 化とコンストラクタの private 化
sealed abstract case class Quantity private (amount: Int)

以下のように Quantity.apply は定義されなくなります。

検証例
scala> sealed abstract case class Quantity private (amount: Int)
defined class Quantity

scala> new Quantity(1){}
<console>:14: error: constructor Quantity in class Quantity cannot be accessed in <$anon: Quantity>
       new Quantity(1){}
           ^

scala> Quantity.apply(1)
<console>:14: error: value apply is not a member of object Quantity
       Quantity.apply(1)

このままだと何もできなくなるため、実際はコンパニオンオブジェクトへ生成用の関数が必要になります。

sealed abstract case class Quantity private (amount: Int)

object Quantity {
  def create(amount: Int): Quantity = new Quantity(amount){}
}

備考

今回の方法は、以下の書籍に記載されているような ADTs(algebraic data types)と Smart constructors をより安全に定義するために活用できると考えています。

Functional and Reactive Domain Modeling

Functional and Reactive Domain Modeling

Kotlin の関数型プログラミング用ライブラリ Λrrow を試してみる

Kotlin で ScalaScalazCats のような関数型プログラミング用のライブラリを探していたところ、以下を見つけたので試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180822/

はじめに

Λrrow は以下のような要素で構成されており、Haskell の型クラス・インスタンスのような仕組みを実現しているようです。

  • Datatypes (Option, Either, Try, Kleisli, StateT, WriterT 等)
  • Typeclasses (Functor, Applicative, Monad, Monoid, Eq, Show 等)
  • Instances (OptionMonadInstance 等)

そして、その実現にはアノテーションプロセッサが活用されているようです。(@higherkind@instance から補助的なコードを生成)

Option のソースコード

例えば、Optionソースコードは以下のようになっていますが、OptionOf の定義は github のソース上には見当たりません。

arrow/core/Option.kt
@higherkind
sealed class Option<out A> : OptionOf<A> {
    ・・・
}

OptionOf はアノテーションプロセッサ ※ で生成したコード(以下)で定義されています。(OptionOf<A>Kind<ForOption, A> の型エイリアス

higherkind.arrow.core.Option.kt
package arrow.core

class ForOption private constructor() { companion object }
typealias OptionOf<A> = arrow.Kind<ForOption, A>

@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <A> OptionOf<A>.fix(): Option<A> =
  this as Option<A>
 ※ @higherkind に対する処理は
    arrow.higherkinds.HigherKindsFileGenerator に実装されているようです

ここで、ForOption というクラスが定義されていますが、この For + 型名 はコンテナ型を表現するための型のようです。

ForOption がある事で Kind<Option<A>, A> ではなく Kind<ForOption, A> と定義できるようになっており、Functor<F>Monad<F> における F の具体型として使われています。(例. Monad<ForOption>

Either のソースコード

次に Eitherソースコードも見てみます。

Either のように型パラメータが複数の場合はアノテーションプロセッサで生成されるコードの内容に多少の違いがあるようです。

arrow/core/Either.kt
@higherkind
sealed class Either<out A, out B> : EitherOf<A, B> {
    ・・・
}

アノテーションプロセッサで生成されたコードは以下のように EitherOf の他に EitherPartialOf が型エイリアスとして定義されています。

higherkind.arrow.core.Either.kt
package arrow.core

class ForEither private constructor() { companion object }
typealias EitherOf<A, B> = arrow.Kind2<ForEither, A, B>
typealias EitherPartialOf<A> = arrow.Kind<ForEither, A>

@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <A, B> EitherOf<A, B>.fix(): Either<A, B> =
  this as Either<A, B>

また、上記で定義されている ForEither クラスとは別に arrow-instances-core モジュールには ForEither 関数 も用意されています。

arrow/instances/either.kt
・・・
class EitherContext<L> : EitherMonadErrorInstance<L>, EitherTraverseInstance<L>, EitherSemigroupKInstance<L> {
  override fun <A, B> Kind<EitherPartialOf<L>, A>.map(f: (A) -> B): Either<L, B> =
    fix().map(f)
}

class EitherContextPartiallyApplied<L> {
  infix fun <A> extensions(f: EitherContext<L>.() -> A): A =
    f(EitherContext())
}

fun <L> ForEither(): EitherContextPartiallyApplied<L> =
  EitherContextPartiallyApplied()

この ForEither 関数は extensions を呼び出す際に使用する事になります。

サンプルコード

Option と Either をそれぞれ使ったサンプルを作成してみます。

(a) Option の利用

Option の SomeNone の作成にはいくつかの方法が用意されています。

Monad の拡張関数として定義されている binding を使うと Haskell の do 記法や Scala の for 式のようにモナドを処理できるようです。(Kotlin の coroutines 機能で実現)

binding の呼び出し方はいくつか考えられますが、ForOption extensions { ・・・ } を使うのが基本のようです。

ForOption extensions へ渡す処理内では this が OptionContext となるため、OptionContext の処理を呼び出せるようになっています。

また、binding へ渡す処理内では MonadContinuation<ForOption, *> が this となります。

bind は MonadContinuation 内で suspend fun <B> Kind<F, B>.bind(): B と定義されており、Kind<F, B> から B の値を取り出す処理となっています。

src/main/kotlin/App.kt
import arrow.core.*
import arrow.instances.*
import arrow.typeclasses.binding

fun main(args: Array<String>) {
    // Some
    val d1: Option<Int> = Option.just(10)
    val d2: Option<Int> = 5.some()
    val d3: Option<Int> = Some(2)
    // None
    val d4: Option<Int> = Option.empty()
    val d5: Option<Int> = none()
    val d6: Option<Int> = None

    // Some(15)
    val r1 = d1.flatMap { a ->
        d2.map { b -> a + b }
    }

    println(r1)

    // Some(17)
    val r2 = Option.monad().binding { // this: MonadContinuation<ForOption, *>
        val a = d1.bind() // 10
        val b = d2.bind() //  5
        val c = d3.bind() //  2
        a + b + c
    }

    println(r2)

    // Some(17)
    val r3 = ForOption extensions { // this: OptionContext
        println(this) // arrow.instances.OptionContext@3ffc5af1

        binding { // this: MonadContinuation<ForOption, *>
            println(this) // arrow.typeclasses.MonadContinuation@26653222

            val a = d1.bind()
            val b = d2.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r3)

    // Some(17)
    val r4 = OptionContext.binding {
        val a = d1.bind()
        val b = d2.bind()
        val c = d3.bind()
        a + b + c
    }

    println(r4)

    // None
    val r5 = Option.monad().binding {
        val a = d1.bind()
        val b = d4.bind()
        a + b
    }

    println(r5)

    // None
    val r6 = ForOption extensions {
        binding {
            val a = d5.bind()
            val b = d2.bind()
            val c = d6.bind()
            a + b + c
        }
    }

    println(r6)
}

Option のような基本的な型を使うだけであれば、依存ライブラリとして arrow-instances-core を指定するだけで問題なさそうです。

そのため、Gradle ビルド定義ファイルは下記のような内容になります。

ここでは The feature "coroutines" is experimental (see: https://kotlinlang.org/docs/diagnostics/experimental-coroutines) 警告ログを出力しないように coroutines を有効化しています。

build.gradle (Gradle ビルド定義)
plugins {
    id 'org.jetbrains.kotlin.jvm' version '1.2.51'
    id 'application'
}

mainClassName = 'AppKt' // 実行するクラス名

repositories {
    jcenter()
}

dependencies {
    compile 'org.jetbrains.kotlin:kotlin-stdlib-jdk8'
    compile 'io.arrow-kt:arrow-instances-core:0.7.3'
}

// coroutines の有効化(ビルド時の警告を抑制)
kotlin {
    experimental {
        coroutines 'enable'
    }
}

実行結果は以下の通りです。

実行結果
> gradle run

・・・
Some(15)
Some(17)
arrow.instances.OptionContext@3ffc5af1
arrow.typeclasses.MonadContinuation@26653222
Some(17)
Some(17)
None
None

(b) Either の利用

Either の場合、extensions の呼び出し方が Option とは多少異なります。

Option の場合は ForOption の extensions を呼び出しましたが、Either の場合は ForEither に extensions は用意されておらず、代わりに EitherContextPartiallyApplied<L> クラス内に定義されています。

EitherContextPartiallyApplied オブジェクトは ForEither 関数 で取得できるので、これを使って ForEither<Left 側の型>() extensions { ・・・ } のようにします。

ForEither<String>() extensions へ渡す処理内の this は EitherContext<String> で、binding へ渡す処理内の this は MonadContinuation<EitherPartialOf<String>, *> となります。

src/main/kotlin/App.kt
import arrow.core.*
import arrow.instances.*
import arrow.typeclasses.binding

fun main(args: Array<String>) {
    // Right
    val d1: Either<String, Int> = Either.right(10)
    val d2: Either<String, Int> = 5.right()
    val d3: Either<String, Int> = Right(2)
    // Left
    val d4: Either<String, Int> = Either.left("error data")

    // Right(b=15)
    val r1 = d1.flatMap { a ->
        d2.map { b -> a + b }
    }

    println(r1)

    // Right(b=17)
    val r2 = Either.monad<String>().binding { // this: MonadContinuation<EitherPartialOf<String>, *>
        val a = d1.bind()
        val b = d2.bind()
        val c = d3.bind()
        a + b + c
    }

    println(r2)

    // Right(b=17)
    // ForEither<String>() 関数の呼び出し
    val r3 = ForEither<String>() extensions { // this: EitherContext<String>
        binding { // this: MonadContinuation<EitherPartialOf<String>, *>
            val a = d1.bind()
            val b = d2.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r3)

    // Left(a=error data)
    val r4 = ForEither<String>() extensions {
        binding {
            val a = d1.bind()
            val b = d4.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r4)
}

(a) と同じ内容の build.gradle を使って実行します。

実行結果
> gradle run

・・・
Right(b=15)
Right(b=17)
Right(b=17)
Left(a=error data)

TypeScript で funfix を使用 - tsc, FuseBox

funfixJavaScript, TypeScript, Flow の関数型プログラミング用ライブラリで、Fantasy LandStatic Land ※ に準拠し Scala の Option, Either, Try, Future 等と同等の型が用意されているようです。

 ※ JavaScript 用に Monoid や Monad 等の代数的構造に関する仕様を定義したもの

今回は Option を使った単純な処理を TypeScript で実装し Node.js で実行してみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180730/

はじめに

Option を使った下記サンプルをコンパイルして実行します。

サンプルソース
import { Option, Some } from 'funfix'

const f = (ma, mb) => ma.flatMap(a => mb.map(b => `${a} + ${b} = ${a + b}`))

const d1 = Some(10)
const d2 = Some(2)

console.log( d1 )

console.log('-----')

console.log( f(d1, d2) )
console.log( f(d1, Option.none()) )

console.log('-----')

console.log( f(d1, d2).getOrElse('none') )
console.log( f(d1, Option.none()).getOrElse('none') )

ビルドと実行

上記ソースファイルを以下の 2通りでビルドして実行してみます。

  • (a) tsc 利用
  • (b) FuseBox 利用

(a) tsc を利用する場合

tsc コマンドを使って TypeScript のソースをコンパイルします。

まずは typescript と funfix モジュールをそれぞれインストールします。

typescript インストール
> npm install --save-dev typescript
funfix インストール
> npm install --save funfix

この状態で sample.ts ファイルをコンパイルしてみると、型関係のエラーが出るものの sample.js は正常に作られました。

コンパイル1
> tsc sample.ts

node_modules/funfix-core/dist/disjunctions.d.ts:775:14 - error TS2416: Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'.
  Type 'undefined' is not assignable to type 'never'.

775     readonly value: undefined;
                 ~~~~~


node_modules/funfix-effect/dist/eval.d.ts:256:42 - error TS2304: Cannot find name 'Iterable'.

256     static sequence<A>(list: Eval<A>[] | Iterable<Eval<A>>): Eval<A[]>;
                                             ~~~~~~~~
・・・

sample.js を実行してみると特に問題無く動作します。

実行1
> node sample.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

これで一応は動いたわけですが、コンパイル時にエラーが出るというのも望ましい状態ではないので、エラー部分を解決してみます。

他にも方法があるかもしれませんが、ここでは以下のように対応します。

  • (1) Property 'value' in type 'TNone' ・・・ 'Option<never>' のエラーに対して tsc 実行時に --strictNullChecks オプションを指定して対応
  • (2) Cannot find name 'Iterable' 等のエラーに対して @types/node をインストールして対応

strictNullChecks は tsc の実行時オプションで指定する以外にも設定ファイル tsconfig.json で設定する事もできるので、ここでは tsconfig.json ファイルを使います。

(1) tsconfig.json
{
  "compilerOptions": {
    "strictNullChecks": true
  }
}

次に @types/node をインストールします。

@types/node には Node.js で実行するための型定義(Node.js 依存の API 等)が TypeScript 用に定義されています。

(2) @types/node インストール
> npm install --save-dev @types/node

この状態で tsc を実行すると先程のようなエラーは出なくなりました。(tsconfig.json を適用するため tsc コマンドへ引数を指定していない点に注意)

コンパイル2
> tsc

実行結果にも差異はありません。

実行2
> node sample.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

最終的な package.json の内容は以下の通りです。

package.json
{
  "name": "sample",
  "version": "1.0.0",
  "devDependencies": {
    "@types/node": "^10.5.4",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

(b) FuseBox を利用する場合

次に、モジュールバンドラーの FuseBox を使用してみます。(以降は (a) とは異なるディレクトリで実施)

なお、ここでは npm の代わりに yarn を使っていますが、npm でも特に問題はありません。

yarn のインストール例(npm 使用)
> npm install -g yarn

(b-1) 型チェック無し

まずは typescript, fuse-box, funfix をそれぞれインストールしておきます。

typescript と fuse-box インストール
> yarn add --dev typescript fuse-box
funfix インストール
> yarn add funfix

FuseBox ではビルド定義を JavaScript のコードで記載します。 とりあえずは必要最小限の設定を行いました。

bundle で指定した名称が init の $name に適用されるため、*.tsコンパイル結果と依存モジュールの内容をバンドルして bundle.js へ出力する事になります。

なお、> でロード時に実行する(コードを記載した)ファイルを指定します。

fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box')

const fuse = FuseBox.init({
    output: '$name.js'
})

fuse.bundle('bundle').instructions('> *.ts')

fuse.run()

上記スクリプトを実行してビルド(TypeScript のコンパイルとバンドル)を行います。

ビルド
> node fuse.js

--- FuseBox 3.4.0 ---
  → Generating recommended tsconfig.json:  ・・・\sample_fusebox1\tsconfig.json
  → Typescript script target: ES7

--------------------------
Bundle "bundle"

    sample.js
└──  (1 files,  700 Bytes) default
└── funfix-core 34.4 kB (1 files)
└── funfix-effect 43.1 kB (1 files)
└── funfix-exec 79.5 kB (1 files)
└── funfix 1 kB (1 files)
size: 158.7 kB in 765ms

初回実行時にデフォルト設定の tsconfig.json が作られました。(tsconfig.json が存在しない場合)

tsc の時のような型関係のエラーは出ていませんが、これは FuseBox がデフォルトで TypeScript の型チェックをしていない事が原因のようです。

型チェックを実施するには fuse-box-typechecker プラグインを使う必要がありそうです。

実行
> node bundle.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

package.json の内容は以下の通りです。

package.json
{
  "name": "sample_fusebox1",
  "version": "1.0.0",
  "main": "bundle.js",
  "license": "MIT",
  "devDependencies": {
    "fuse-box": "^3.4.0",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

(b-2) 型チェック有り

TypeScript の型チェックを行うようにしてみます。

まずは、(b-1) と同じ構成に fuse-box-typechecker プラグインを加えます。

fuse-box-typechecker を追加インストール
> yarn add --dev fuse-box-typechecker

次に、fuse.js へ fuse-box-typechecker プラグインの設定を追加します。

TypeChecker で型チェックにエラーがあった場合、例外が throw されるようにはなっていないため、ここではエラーがあった場合に Error を throw して fuse.run() を実行しないようにしてみました。

ただし、こうすると tsconfig.json を予め用意しておく必要があります。(TypeChecker に tsconfig.json が必要)

fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box')
const {TypeChecker} = require('fuse-box-typechecker')

// fuse-box-typechecker の設定
const testSync = TypeChecker({
    tsConfig: './tsconfig.json'
})

const fuse = FuseBox.init({
    output: '$name.js'
})

fuse.bundle('bundle').instructions('> *.ts')

testSync.runPromise()
    .then(n => {
        if (n != 0) {
            // 型チェックでエラーがあった場合
            throw new Error(n)
        }
        // 型チェックでエラーがなかった場合
        return fuse.run()
    })
    .catch(console.error)

これで、ビルド時に (a) と同様の型エラーが出るようになりました。

ビルド1
> node fuse.js

・・・
--- FuseBox 3.4.0 ---

Typechecker plugin(promisesync) .
Time:Sun Jul 29 2018 12:40:47 GMT+0900 (GMT+09:00)

File errors:
└── .\node_modules\funfix-core\dist\disjunctions.d.ts
   | ・・・\sample_fusebox2\node_modules\funfix-core\dist\disjunctions.d.ts (775,14) (Error:TS2416) Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'.
  Type 'undefined' is not assignable to type 'never'.

Errors:1
└── Options: 0
└── Global: 0
└── Syntactic: 0
└── Semantic: 1
└── TsLint: 0

Typechecking time: 4116ms
Quitting typechecker

・・・

ここで、Iterable の型エラーが出ていないのは fuse-box-typechecker のインストール時に @types/node もインストールされているためです。

(a) と同様に strictNullChecks の設定を tsconfig.json追記して、このエラーを解決します。

tsconfig.json へ strictNullChecks の設定を追加
{
  "compilerOptions": {
    "module": "commonjs",
    "target": "ES7",
    ・・・
    "strictNullChecks": true
  }
}

これでビルドが成功するようになりました。

ビルド2
> node fuse.js

・・・
Typechecker name: undefined
Typechecker basepath: ・・・\sample_fusebox2
Typechecker tsconfig: ・・・\sample_fusebox2\tsconfig.json
--- FuseBox 3.4.0 ---

Typechecker plugin(promisesync) .
Time:Sun Jul 29 2018 12:44:57 GMT+0900 (GMT+09:00)
All good, no errors :-)
Typechecking time: 4103ms
Quitting typechecker

killing worker  → Typescript config file:  \tsconfig.json
  → Typescript script target: ES7

--------------------------
Bundle "bundle"

    sample.js
└──  (1 files,  700 Bytes) default
└── funfix-core 34.4 kB (1 files)
└── funfix-effect 43.1 kB (1 files)
└── funfix-exec 79.5 kB (1 files)
└── funfix 1 kB (1 files)
size: 158.7 kB in 664ms
実行結果
> node bundle.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

package.json の内容は以下の通りです。

package.json
{
  "name": "sample_fusebox2",
  "version": "1.0.0",
  "main": "index.js",
  "license": "MIT",
  "devDependencies": {
    "fuse-box": "^3.4.0",
    "fuse-box-typechecker": "^2.10.0",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

Word2Vec を用いた併売の分析 - gensim

トピックモデルを用いた併売の分析」ではトピックモデルによる併売の分析を試しましたが、今回は gensim の Word2Vec で試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180617/

はじめに

データセット

これまで は適当に作ったデータセットを使っていましたが、今回は R の Groceries データセット ※ をスペース区切りのテキストファイル(groceries.txt)にして使います。(商品名にスペースを含む場合は代わりに _ を使っています)

 ※ ある食料雑貨店における 30日間の POS データ
groceries.txt
citrus_fruit semi-finished_bread margarine ready_soups
tropical_fruit yogurt coffee
whole_milk
pip_fruit yogurt cream_cheese_ meat_spreads
other_vegetables whole_milk condensed_milk long_life_bakery_product
whole_milk butter yogurt rice abrasive_cleaner
rolls/buns
other_vegetables UHT-milk rolls/buns bottled_beer liquor_(appetizer)
pot_plants
whole_milk cereals
・・・
cooking_chocolate
chicken citrus_fruit other_vegetables butter yogurt frozen_dessert domestic_eggs rolls/buns rum cling_film/bags
semi-finished_bread bottled_water soda bottled_beer
chicken tropical_fruit other_vegetables vinegar shopping_bags

R によるアソシエーションルールの抽出結果

参考のため、まずは R を使って Groceries データセットapriori で処理しました。

リフト値を優先するため、支持度 (supp) と確信度 (conf) を低めの値にしています。

groceries_apriori.R
library(arules)
data(Groceries)

params <- list(supp = 0.001, conf = 0.1)

rules <- apriori(Groceries, parameter = params)

inspect(head(sort(rules, by = "lift"), 10))
実行結果
> Rscript groceries_apriori.R

・・・
     lhs                        rhs                         support confidence     lift count
[1]  {bottled beer,                                                                          
      red/blush wine}        => {liquor}                0.001931876  0.3958333 35.71579    19
[2]  {hamburger meat,                                                                        
      soda}                  => {Instant food products} 0.001220132  0.2105263 26.20919    12
[3]  {ham,                                                                                   
      white bread}           => {processed cheese}      0.001931876  0.3800000 22.92822    19
[4]  {root vegetables,                                                                       
      other vegetables,                                                                      
      whole milk,                                                                            
      yogurt}                => {rice}                  0.001321810  0.1688312 22.13939    13
[5]  {bottled beer,                                                                          
      liquor}                => {red/blush wine}        0.001931876  0.4130435 21.49356    19
[6]  {Instant food products,                                                                 
      soda}                  => {hamburger meat}        0.001220132  0.6315789 18.99565    12
[7]  {curd,                                                                                  
      sugar}                 => {flour}                 0.001118454  0.3235294 18.60767    11
[8]  {soda,                                                                                  
      salty snack}           => {popcorn}               0.001220132  0.1304348 18.06797    12
[9]  {sugar,                                                                                 
      baking powder}         => {flour}                 0.001016777  0.3125000 17.97332    10
[10] {processed cheese,                                                                      
      white bread}           => {ham}                   0.001931876  0.4634146 17.80345    19

bottled beer と red/blush wine で liquor が同時に買われやすい、hamburger meat と soda で Instant food products が同時に買われやすいという結果(アソシエーションルール)が出ています。

Word2Vec の適用

それでは groceries.txt を gensim の Word2Vec で処理してみます。 とりあえず iter を 500 に min_count を 1 にしてみました。

なお、購入品目の多い POS データを処理する場合は window パラメータを大きめにすべきかもしれません。(今回はデフォルト値の 5)

今回は Jupyter Notebook で実行しています。

Word2Vec モデルの構築
from gensim.models import word2vec

sentences = word2vec.LineSentence('groceries.txt')

model = word2vec.Word2Vec(sentences, iter = 500, min_count = 1)

類似品の算出

まずは、wv.most_similar で類似単語(商品)を抽出してみます。

pork の類似単語
model.wv.most_similar('pork')
[('turkey', 0.5547687411308289),
 ('ham', 0.49448296427726746),
 ('pip_fruit', 0.46879759430885315),
 ('tropical_fruit', 0.4383287727832794),
 ('butter', 0.43373265862464905),
 ('frankfurter', 0.4334157109260559),
 ('root_vegetables', 0.4249211549758911),
 ('citrus_fruit', 0.4246293306350708),
 ('chicken', 0.42378148436546326),
 ('sausage', 0.41153857111930847)]

微妙なものも含んでいますが、それなりの結果になっているような気もします。

most_similar はベクトル的に類似している単語を抽出するため、POS データを処理する場合は競合や代用品の抽出に使えるのではないかと思います。

併売の分析

併売の商品はお互いに類似していないと思うので most_similar は役立ちそうにありませんが、それでも何らかの関係性はありそうな気がします。

そこで、指定した単語群の中心となる単語を抽出する predict_output_word を使えないかと思い、R で抽出したアソシエーションルールの組み合わせで試してみました。

predict_output_word の検証

bottled_beer と red/blush_wine
model.predict_output_word(['bottled_beer', 'red/blush_wine'])
[('liquor', 0.22384332),
 ('prosecco', 0.04933687),
 ('sparkling_wine', 0.0345262),
 ・・・]

R の結果に出てた liquor が先頭(確率が最大)に来ています。

bottled_beer と red/blush_wine
model.predict_output_word(['hamburger_meat', 'soda'])
[('Instant_food_products', 0.054281656),
 ('canned_vegetables', 0.029985178),
 ('pasta', 0.025487985),
 ・・・]

ここでも R の結果に出てた Instant_food_products が先頭に来ています。

ham と white_bread
model.predict_output_word(['ham', 'white_bread'])
[('processed_cheese', 0.20990367),
 ('sweet_spreads', 0.024131883),
 ('spread_cheese', 0.023222428),
 ・・・]

こちらも同様です。

root_vegetables と other_vegetables と whole_milk と yogurt
model.predict_output_word(['root_vegetables', 'other_vegetables', 'whole_milk', 'yogurt'])
[('herbs', 0.024541182),
 ('liver_loaf', 0.019327056),
 ('turkey', 0.01775743),
 ('onions', 0.01760579),
 ('specialty_cheese', 0.014991459),
 ('packaged_fruit/vegetables', 0.014529809),
 ('spread_cheese', 0.012931713),
 ('meat', 0.012434797),
 ('beef', 0.011924307),
 ('butter_milk', 0.011828974)]

R の結果にあった rice はこの中には含まれていません。

curd と sugar
model.predict_output_word(['curd', 'sugar'])
[('flour', 0.076272935),
 ('pudding_powder', 0.055790607),
 ('baking_powder', 0.026003197),
 ・・・]

R の結果に出てた flour (小麦粉) が先頭に来ています。

soda と salty_snack
model.predict_output_word(['soda', 'salty_snack'])
[('popcorn', 0.05830234),
 ('nut_snack', 0.046429735),
 ('chewing_gum', 0.0213278),
 ・・・]

こちらも同様です。

sugar と baking_powder
model.predict_output_word(['sugar', 'baking_powder'])
[('flour', 0.11954326),
 ('cooking_chocolate', 0.046284538),
 ('pudding_powder', 0.03714784),
 ・・・]

こちらも同様です。

以上のように、少なくとも 2品を指定した場合の predict_output_word の結果は R で抽出したアソシエーションルールに合致しているようです。

Word2Vec のパラメータに左右されるのかもしれませんが、この結果を見る限りでは predict_output_word を 3品の併売の組み合わせ抽出に使えるかもしれない事が分かりました。

3品の併売

次に predict_output_word で 2品に対する 1品を確率の高い順に抽出してみました。

なお、ここでは 3品の組み合わせの購入数が 10 未満のものは除外するようにしています。

from collections import Counter
import itertools

# 3品の組み合わせのカウント
tri_counter = Counter([c for ws in sentences for c in itertools.combinations(sorted(ws), 3)])

# 2品の組み合わせを作成
pairs = itertools.combinations(model.wv.vocab.keys(), 2)

sorted([
    (p, item, prob) for p in pairs for item, prob in model.predict_output_word(p)
    if prob >= 0.05 and tri_counter[tuple(sorted([p[0], p[1], item]))] >= 10
], key = lambda x: -x[2])
[(('bottled_beer', 'red/blush_wine'), 'liquor', 0.22384332),
 (('white_bread', 'ham'), 'processed_cheese', 0.20990367),
 (('bottled_beer', 'liquor'), 'red/blush_wine', 0.16274776),
 (('sugar', 'baking_powder'), 'flour', 0.11954326),
 (('curd', 'sugar'), 'flour', 0.076272935),
 (('margarine', 'sugar'), 'flour', 0.07422828),
 (('flour', 'sugar'), 'baking_powder', 0.07345509),
 (('sugar', 'whipped/sour_cream'), 'flour', 0.072731614),
 (('rolls/buns', 'hamburger_meat'), 'Instant_food_products', 0.06818052),
 (('sugar', 'root_vegetables'), 'flour', 0.0641469),
 (('tropical_fruit', 'white_bread'), 'processed_cheese', 0.061861355),
 (('soda', 'ham'), 'processed_cheese', 0.06138085),
 (('white_bread', 'processed_cheese'), 'ham', 0.061199907),
 (('whole_milk', 'ham'), 'processed_cheese', 0.059773713),
 (('beef', 'root_vegetables'), 'herbs', 0.059243686),
 (('sugar', 'whipped/sour_cream'), 'baking_powder', 0.05871357),
 (('soda', 'salty_snack'), 'popcorn', 0.05830234),
 (('soda', 'popcorn'), 'salty_snack', 0.05819882),
 (('red/blush_wine', 'liquor'), 'bottled_beer', 0.057226427),
 (('flour', 'baking_powder'), 'sugar', 0.05517209),
 (('soda', 'hamburger_meat'), 'Instant_food_products', 0.054281656),
 (('processed_cheese', 'ham'), 'white_bread', 0.053193364),
 (('other_vegetables', 'ham'), 'processed_cheese', 0.052585844)]

R で抽出したアソシエーションルールと同じ様な結果が出ており、それなりの結果が出ているように思います。

skip-gram の場合

gensim の Word2Vec はデフォルトで CBoW を使うようですので、skip-gram の場合にどうなるかも簡単に確認してみました。

skip-gram の使用
model = word2vec.Word2Vec(sentences, iter = 500, min_count = 1, sg = 1)

まずは predict_output_word の結果をいくつか見てみます。

先頭(確率が最大のもの)は変わらないようですが、CBoW よりも確率の値が全体的に低くなっているようです。

bottled_beer と red/blush_wine
model.predict_output_word(['bottled_beer', 'red/blush_wine'])
[('liquor', 0.076620705),
 ('prosecco', 0.030791236),
 ('liquor_(appetizer)', 0.027123762),
 ・・・]
hamburger_meat と soda
model.predict_output_word(['hamburger_meat', 'soda'])
[('Instant_food_products', 0.022627866),
 ('pasta', 0.018009944),
 ('canned_vegetables', 0.01685342),
 ・・・]
root_vegetables と other_vegetables と whole_milk と yogurt
model.predict_output_word(['root_vegetables', 'other_vegetables', 'whole_milk', 'yogurt'])
[('herbs', 0.015105391),
 ('turkey', 0.014365919),
 ('rice', 0.01316431),
 ・・・]

ここでは、CBoW で 10番以内に入っていなかった rice が入っています。

次に、先程と同様に predict_output_word で 3品の組み合わせを確率順に抽出してみます。

確率の値が全体的に下がっているため、最小値の条件を 0.02 へ変えています。

predict_output_word を使った 3品の組み合わせ抽出
・・・
sorted([
    (p, item, prob) for p in pairs for item, prob in model.predict_output_word(p)
    if prob >= 0.02 and tri_counter[tuple(sorted([p[0], p[1], item]))] >= 10
], key = lambda x: -x[2])
[(('bottled_beer', 'red/blush_wine'), 'liquor', 0.076620705),
 (('bottled_beer', 'liquor'), 'red/blush_wine', 0.0712179),
 (('white_bread', 'ham'), 'processed_cheese', 0.039820198),
 (('red/blush_wine', 'liquor'), 'bottled_beer', 0.031292748),
 (('sugar', 'baking_powder'), 'flour', 0.030803043),
 (('sugar', 'whipped/sour_cream'), 'flour', 0.029322423),
 (('margarine', 'sugar'), 'flour', 0.027827),
 (('beef', 'root_vegetables'), 'herbs', 0.02740662),
 (('curd', 'sugar'), 'flour', 0.025570681),
 (('flour', 'sugar'), 'baking_powder', 0.025403246),
 (('tropical_fruit', 'root_vegetables'), 'turkey', 0.025329975),
 (('whole_milk', 'ham'), 'processed_cheese', 0.024535457),
 (('rolls/buns', 'hamburger_meat'), 'Instant_food_products', 0.02427808),
 (('flour', 'baking_powder'), 'sugar', 0.023779714),
 (('tropical_fruit', 'white_bread'), 'processed_cheese', 0.023528077),
 (('sugar', 'root_vegetables'), 'flour', 0.023394365),
 (('soda', 'salty_snack'), 'popcorn', 0.02322538),
 (('whole_milk', 'sugar'), 'flour', 0.023202542),
 (('fruit/vegetable_juice', 'ham'), 'processed_cheese', 0.023127634),
 (('butter', 'root_vegetables'), 'herbs', 0.02304014),
 (('soda', 'ham'), 'processed_cheese', 0.022633638),
 (('soda', 'hamburger_meat'), 'Instant_food_products', 0.022627866),
 (('citrus_fruit', 'sugar'), 'flour', 0.022040429),
 (('bottled_beer', 'soda'), 'liquor', 0.02189085),
 (('processed_cheese', 'ham'), 'white_bread', 0.021692872),
 (('yogurt', 'sugar'), 'flour', 0.021522585),
 (('tropical_fruit', 'other_vegetables'), 'turkey', 0.021456005),
 (('other_vegetables', 'beef'), 'herbs', 0.021407435),
 (('white_bread', 'processed_cheese'), 'ham', 0.021362728),
 (('curd', 'root_vegetables'), 'herbs', 0.021005861),
 (('other_vegetables', 'ham'), 'processed_cheese', 0.020965746),
 (('root_vegetables', 'whipped/sour_cream'), 'herbs', 0.020788824),
 (('other_vegetables', 'root_vegetables'), 'herbs', 0.020782541),
 (('sugar', 'whipped/sour_cream'), 'baking_powder', 0.02058014),
 (('whole_milk', 'sugar'), 'rice', 0.020371588),
 (('root_vegetables', 'frozen_vegetables'), 'herbs', 0.02027719),
 (('whole_milk', 'Instant_food_products'), 'hamburger_meat', 0.020258738),
 (('citrus_fruit', 'root_vegetables'), 'herbs', 0.020241175)]

最小値の条件を下げたために、より多くの組み合わせを抽出していますが、CBoW の結果と大きな違いは無さそうです。

quill で DDL を実行

quillScala 用の DB ライブラリで、マクロを使ってコンパイル時に SQL や CQL(Cassandra)を組み立てるのが特徴となっています。

quill には Infix という機能が用意されており、これを使うと FOR UPDATE のような(quillが)未サポートの SQL 構文に対応したり、select 文を直接指定したりできるようですが、CREATE TABLE のような DDL(データ定義言語)の実行は無理そうでした。

そこで、API やソースを調べてみたところ、SQL を直接実行する probeexecuteAction という関数を見つけたので、これを使って CREATE TABLE を実行してみたいと思います。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180502/

はじめに

今回は Gradle を使ってビルド・実行し、DB には H2 を(インメモリーで)使います。

build.gradle
apply plugin: 'scala'
apply plugin: 'application'

mainClassName = 'sample.SampleApp'

repositories {
    jcenter()
}

dependencies {
    compile 'org.scala-lang:scala-library:2.12.6'
    compile 'io.getquill:quill-jdbc_2.12:2.4.2'

    runtime 'com.h2database:h2:1.4.192'
    runtime 'org.slf4j:slf4j-simple:1.8.0-beta2'
}

DB の接続設定は以下のようにしました。

ctx の部分は任意の文字列を用いることができ、H2JdbcContext を new する際の configPrefix 引数で指定します。

src/main/resources/application.conf
ctx.dataSourceClassName=org.h2.jdbcx.JdbcDataSource
ctx.dataSource.url="jdbc:h2:mem:sample"
ctx.dataSource.user=sa

1. probe・executeAction で DDL を実行

それでは、probeexecuteAction をそれぞれ使って CREATE TABLE を実行してみます。

JdbcContext における probe の戻り値は Try[Boolean]executeAction の戻り値は Long となっています。

sample1/src/main/scala/sample/SampleApp.scala
package sample

import io.getquill.{H2JdbcContext, SnakeCase}

case class Item(itemId: String, name: String)
case class Stock(stockId: String, itemId: String, qty: Int)

object SampleApp extends App {
  lazy val ctx = new H2JdbcContext(SnakeCase, "ctx")

  import ctx._

  // probe を使った CREATE TABLE の実行
  val r1 = probe("CREATE TABLE item(item_id VARCHAR(10) PRIMARY KEY, name VARCHAR(10))")
  println(s"create table1: $r1")

  // executeAction を使った CREATE TABLE の実行
  val r2 = executeAction("CREATE TABLE stock(stock_id VARCHAR(10) PRIMARY KEY, item_id VARCHAR(10), qty INT)")
  println(s"create table2: $r2")

  // item への insert
  println( run(query[Item].insert(lift(Item("item1", "A1")))) )
  println( run(query[Item].insert(lift(Item("item2", "B2")))) )

  // stock への insert
  println( run(query[Stock].insert(lift(Stock("stock1", "item1", 5)))) )
  println( run(query[Stock].insert(lift(Stock("stock2", "item2", 3)))) )

  // item の select
  println( run(query[Item]) )
  // stock の select
  println( run(query[Stock]) )

  // Infix を使った select
  val selectStocks = quote(
    infix"""SELECT stock_id AS "_1", name AS "_2", qty AS "_3"
            FROM stock s join item i on i.item_id = s.item_id""".as[Query[(String, String, Int)]]
  )
  println( run(selectStocks) )
}

実行結果は以下の通りで CREATE TABLE に成功しています。 probe の結果は Success(false) で executeAction の結果は 0 となりました。

実行結果
> cd sample1
> gradle run

・・・
[main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Starting...
[main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Start completed.

create table1: Success(false)
create table2: 0
1
1
1
1
List(Item(item1,A1), Item(item2,B2))
List(Stock(stock1,item1,5), Stock(stock2,item2,3))
List((stock1,A1,5), (stock2,B2,3))

・・・

2. モナドの利用

quill には IO モナドが用意されていたので、これを使って処理を組み立ててみます。

IO は run の代わりに runIO を使う事で取得でき、IO の結果は performIO で取得します。

probe の結果である Try[A]IO.fromTry を使う事で IO にできます。

また、クエリー query[A] では flatMap 等が使えるので for 内包表記で直接合成できましたが(selectStocks の箇所)、query[A].insert(・・・) は flatMap 等を使えなかったので runIO しています。(insertItemAndStock の箇所)

sample2/src/main/scala/sample/SampleApp.scala
package sample

import io.getquill.{H2JdbcContext, SnakeCase}

case class Item(itemId: String, name: String)
case class Stock(stockId: String, itemId: String, qty: Int)

object SampleApp extends App {
  lazy val ctx = new H2JdbcContext(SnakeCase, "ctx")

  import ctx._

  // CREATE TABLE
  val createTables = for {
    it <- probe("CREATE TABLE item(item_id VARCHAR(10) PRIMARY KEY, name VARCHAR(10))")
    st <- probe("CREATE TABLE stock(stock_id VARCHAR(10) PRIMARY KEY, item_id VARCHAR(10), qty INT)")
  } yield (it, st)

  // item と stock へ insert
  val insertItemAndStock = (itemId: String, name: String, stockId: String, qty: Int) => for {
    _ <- runIO( query[Item].insert(lift(Item(itemId, name))) )
    _ <- runIO( query[Stock].insert(lift(Stock(stockId, itemId, qty))) )
  } yield ()

  // stock と item の select(stock と該当する item をタプル化)
  val selectStocks = quote {
    for {
      s <- query[Stock]
      i <- query[Item] if i.itemId == s.itemId
    } yield (i, s)
  }

  // 処理の合成
  val proc = for {
    r1 <- IO.fromTry(createTables)
    _ <- insertItemAndStock("item1", "A1", "stock1", 5)
    _ <- insertItemAndStock("item2", "B2", "stock2", 3)
    r2 <- runIO(selectStocks)
  } yield (r1, r2)

  // 結果
  println( performIO(proc) )
  // トランザクションを適用する場合は以下のようにする
  //println( performIO(proc.transactional) )
}
実行結果
> cd sample2
> gradle run

・・・
[main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Starting...
[main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Start completed.

((false,false),List((Item(item1,A1),Stock(stock1,item1,5)), (Item(item2,B2),Stock(stock2,item2,3))))

・・・