CNN で輪郭の検出
画像内の物体の輪郭検出を CNN(畳み込みニューラルネット)で試してみました。
- Keras + Tensorflow
- Jupyter Notebook
ソースは http://github.com/fits/try_samples/tree/master/blog/20190114/
概要
今回は、画像をピクセル単位で輪郭か否かに分類する事(輪郭 = 1, 輪郭以外 = 0)で輪郭を検出できないか試しました。
そこで、教師データとして以下のような衣服単体の画像(jpg)と衣服の輪郭部分だけを白く塗りつぶした画像(png)を用意しました。
教師データを大量に用意するのは困難だったため、240x288 の画像 160 ファイルで学習を行っています。
学習
学習の処理は Jupyter Notebook 上で実行しました。
(1) 入力データの準備
まずは、入力画像(jpg)を読み込みます。(教師データの画像は img ディレクトリへ配置しています)
import glob import numpy as np from keras.preprocessing.image import load_img, img_to_array files = glob.glob('img/*.jpg') imgs = np.array([img_to_array(load_img(f)) for f in files]) imgs.shape
入力データの形状は以下の通りです。
imgs.shape 結果
(160, 288, 240, 3)
(2) ラベルデータの準備
輪郭画像(png)を読み込み、128 を境にして二値化(輪郭 = 1、輪郭以外 = 0)します。
import os th = 128 labels = np.array([img_to_array(load_img(f"{os.path.splitext(f)[0]}.png", color_mode = 'grayscale')) for f in files]) labels[labels < th] = 0 labels[labels >= th] = 1 labels.shape
ラベルデータの形状は以下の通りです。
labels.shape 結果
(160, 288, 240, 1)
(3) CNN モデル
どのようなネットワーク構成が適しているのか分からなかったので、セマンティックセグメンテーション等で用いられている Encoder-Decoder の構成を参考にしてみました。
30x36 まで段階的に縮小して(Encoder)、元の大きさ 240x288 まで段階的に拡大する(Decoder)ようにしています。
最終層の活性化関数に sigmoid
を使って 0 ~ 1 の値となるようにしています。
損失関数は binary_crossentropy
を使うと進捗が遅そうに見えたので ※、代わりに mean_squared_error
を使っています。
※ 今回の場合、輪郭(= 1)に該当するピクセルの方が少なくなるため binary_crossentropy を使用する場合は fit の class_weight 引数で 調整する必要があったと思われます
また、参考のため mean_absolute_error
(mae) の値も出力するように metrics
で指定しています。
from keras.models import Model from keras.layers import Input, Dropout from keras.layers.convolutional import Conv2D, Conv2DTranspose from keras.layers.pooling import MaxPool2D from keras.layers.normalization import BatchNormalization input = Input(shape = imgs.shape[1:]) x = input x = BatchNormalization()(x) # Encoder x = Conv2D(16, 3, padding='same', activation = 'relu')(x) x = MaxPool2D()(x) x = BatchNormalization()(x) x = Dropout(0.3)(x) x = Conv2D(32, 3, padding='same', activation = 'relu')(x) x = Conv2D(32, 3, padding='same', activation = 'relu')(x) x = Conv2D(32, 3, padding='same', activation = 'relu')(x) x = MaxPool2D()(x) x = BatchNormalization()(x) x = Dropout(0.3)(x) x = Conv2D(64, 3, padding='same', activation = 'relu')(x) x = Conv2D(64, 3, padding='same', activation = 'relu')(x) x = Conv2D(64, 3, padding='same', activation = 'relu')(x) x = MaxPool2D()(x) x = BatchNormalization()(x) x = Dropout(0.3)(x) x = Conv2D(128, 3, padding='same', activation = 'relu')(x) x = Conv2D(128, 3, padding='same', activation = 'relu')(x) x = BatchNormalization()(x) x = Dropout(0.3)(x) # Decoder x = Conv2DTranspose(64, 3, strides = 2, padding='same', activation = 'relu')(x) x = Conv2D(64, 3, padding='same', activation = 'relu')(x) x = Conv2D(64, 3, padding='same', activation = 'relu')(x) x = Conv2D(64, 3, padding='same', activation = 'relu')(x) x = BatchNormalization()(x) x = Dropout(0.3)(x) x = Conv2DTranspose(32, 3, strides = 2, padding='same', activation = 'relu')(x) x = Conv2D(32, 3, padding='same', activation = 'relu')(x) x = Conv2D(32, 3, padding='same', activation = 'relu')(x) x = Conv2D(32, 3, padding='same', activation = 'relu')(x) x = BatchNormalization()(x) x = Dropout(0.3)(x) x = Conv2DTranspose(16, 3, strides = 2, padding='same', activation = 'relu')(x) x = Conv2D(16, 3, padding='same', activation = 'relu')(x) x = BatchNormalization()(x) x = Dropout(0.3)(x) output = Conv2D(1, 1, activation = 'sigmoid')(x) model = Model(inputs = input, outputs = output) model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae']) model.summary()
model.summary() 結果
Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) (None, 288, 240, 3) 0 _________________________________________________________________ batch_normalization_1 (Batch (None, 288, 240, 3) 12 _________________________________________________________________ conv2d_1 (Conv2D) (None, 288, 240, 16) 448 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 144, 120, 16) 0 _________________________________________________________________ batch_normalization_2 (Batch (None, 144, 120, 16) 64 _________________________________________________________________ dropout_1 (Dropout) (None, 144, 120, 16) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 144, 120, 32) 4640 _________________________________________________________________ conv2d_3 (Conv2D) (None, 144, 120, 32) 9248 _________________________________________________________________ conv2d_4 (Conv2D) (None, 144, 120, 32) 9248 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 72, 60, 32) 0 _________________________________________________________________ batch_normalization_3 (Batch (None, 72, 60, 32) 128 _________________________________________________________________ dropout_2 (Dropout) (None, 72, 60, 32) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 72, 60, 64) 18496 _________________________________________________________________ conv2d_6 (Conv2D) (None, 72, 60, 64) 36928 _________________________________________________________________ conv2d_7 (Conv2D) (None, 72, 60, 64) 36928 _________________________________________________________________ max_pooling2d_3 (MaxPooling2 (None, 36, 30, 64) 0 _________________________________________________________________ batch_normalization_4 (Batch (None, 36, 30, 64) 256 _________________________________________________________________ dropout_3 (Dropout) (None, 36, 30, 64) 0 _________________________________________________________________ conv2d_8 (Conv2D) (None, 36, 30, 128) 73856 _________________________________________________________________ conv2d_9 (Conv2D) (None, 36, 30, 128) 147584 _________________________________________________________________ batch_normalization_5 (Batch (None, 36, 30, 128) 512 _________________________________________________________________ dropout_4 (Dropout) (None, 36, 30, 128) 0 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 72, 60, 64) 73792 _________________________________________________________________ conv2d_10 (Conv2D) (None, 72, 60, 64) 36928 _________________________________________________________________ conv2d_11 (Conv2D) (None, 72, 60, 64) 36928 _________________________________________________________________ conv2d_12 (Conv2D) (None, 72, 60, 64) 36928 _________________________________________________________________ batch_normalization_6 (Batch (None, 72, 60, 64) 256 _________________________________________________________________ dropout_5 (Dropout) (None, 72, 60, 64) 0 _________________________________________________________________ conv2d_transpose_2 (Conv2DTr (None, 144, 120, 32) 18464 _________________________________________________________________ conv2d_13 (Conv2D) (None, 144, 120, 32) 9248 _________________________________________________________________ conv2d_14 (Conv2D) (None, 144, 120, 32) 9248 _________________________________________________________________ conv2d_15 (Conv2D) (None, 144, 120, 32) 9248 _________________________________________________________________ batch_normalization_7 (Batch (None, 144, 120, 32) 128 _________________________________________________________________ dropout_6 (Dropout) (None, 144, 120, 32) 0 _________________________________________________________________ conv2d_transpose_3 (Conv2DTr (None, 288, 240, 16) 4624 _________________________________________________________________ conv2d_16 (Conv2D) (None, 288, 240, 16) 2320 _________________________________________________________________ batch_normalization_8 (Batch (None, 288, 240, 16) 64 _________________________________________________________________ dropout_7 (Dropout) (None, 288, 240, 16) 0 _________________________________________________________________ conv2d_17 (Conv2D) (None, 288, 240, 1) 17 ================================================================= Total params: 576,541 Trainable params: 575,831 Non-trainable params: 710 _________________________________________________________________
(4) 学習
教師データが少ないため、全て学習で使う事にします。
実行例(441 ~ 480 エポック)
hist = model.fit(imgs, labels, initial_epoch = 440, epochs = 480, batch_size = 10)
Keras では fit を繰り返し呼び出すと学習を(続きから)再開できるので、40 エポックを何回か繰り返しました。(バッチサイズは 20 で始めて途中で 10 へ変えたりしています)
その場合、正しいエポックを出力するには initial_epoch
と epochs
の値を調整する必要があります ※。
※ initial_epoch を指定しなくても fit を繰り返し実行するだけで学習は継続されますが、 その場合は出力されるエポックの値がクリアされます(1 からのカウントとなる)
結果例
Epoch 441/480 160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126 Epoch 442/480 160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0125 Epoch 443/480 160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126 ・・・ Epoch 478/480 160/160 [=====・・・ - loss: 0.0045 - mean_absolute_error: 0.0116 Epoch 479/480 160/160 [=====・・・ - loss: 0.0046 - mean_absolute_error: 0.0117 Epoch 480/480 160/160 [=====・・・ - loss: 0.0044 - mean_absolute_error: 0.0115
(5) 確認
fit
の戻り値から mean_squared_error(loss)と mean_absolute_error の値の遷移をグラフ化してみます。
%matplotlib inline import matplotlib.pyplot as plt plt.rcParams['figure.figsize'] = (8, 4) plt.subplot(1, 2, 1) plt.plot(hist.history['loss']) plt.subplot(1, 2, 2) plt.plot(hist.history['mean_absolute_error'])
結果例(441 ~ 480 エポック)
(6) 検証
(a) 教師データ
教師データの入力画像と輪郭画像(ラベルデータ)、model.predict
の結果を並べて表示してみます。
def predict(index, s = 6.0): plt.rcParams['figure.figsize'] = (s, s) sh = imgs.shape[1:-1] # 輪郭の検出(予測処理) pred = model.predict(np.array([imgs[index]]))[0] pred *= 255 plt.subplot(1, 3, 1) # 入力画像の表示 plt.imshow(imgs[index].astype(int)) plt.subplot(1, 3, 2) # 輪郭画像(ラベルデータ)の表示 plt.imshow(labels[index].reshape(sh), cmap = 'gray') plt.subplot(1, 3, 3) # predict の結果表示 plt.imshow(pred.reshape(sh).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, 輪郭画像, model.predict 結果
概ね教師データに近い結果が出るようになっています。
(b) 教師データ以外
教師データとして使っていない画像に対して model.predict
を実施し、輪郭の検出を行ってみます。
def predict_eval(file, s = 4.0): plt.rcParams['figure.figsize'] = (s, s) img = img_to_array(load_img(file)) # 輪郭の検出(予測処理) pred = model.predict(np.array([img]))[0] pred *= 255 plt.subplot(1, 2, 1) # 入力画像の表示 plt.imshow(img.astype(int)) plt.subplot(1, 2, 2) # predict の結果表示 plt.imshow(pred.reshape(pred.shape[:-1]).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, model.predict 結果
所々で途切れたりしていますが、ある程度の輪郭は検出できているように見えます。
(7) 保存
学習したモデルを保存します。
model.save('model/c1_480.h5')
輪郭検出
学習済みモデルを使って輪郭検出を行う処理をスクリプト化してみました。
predict_contours.py
import sys import os import glob import numpy as np from keras.preprocessing.image import load_img, img_to_array from keras.models import load_model import cv2 model_file = sys.argv[1] img_files = sys.argv[2] dest_dir = sys.argv[3] model = load_model(model_file) for f in glob.glob(img_files): img = img_to_array(load_img(f)) # 輪郭の検出 pred = model.predict(np.array([img]))[0] pred *= 255 file, ext = os.path.splitext(os.path.basename(f)) # 画像の保存 cv2.imwrite(f"{dest_dir}/{file}_predict.png", pred) print(f"done: {f}")
実行例
python predict_contours.py model/c1_480.h5 img_eval2/*.jpg result
480 エポックの学習モデルを使って、教師データに無いタイプの背景を使った画像(影の影響もある)に試してみました。
輪郭検出結果例
入力画像 | 処理結果 |
---|---|
こちらは難しかったようです。
なお、2つ目の画像は 120 エポックの学習モデルの方が良好な結果(輪郭がより多く検出されていた)でした。
MongoDB で条件に合致する子要素を抽出
MongoDB で指定の条件に合致する子要素のみを抽出する方法を調査してみました。
- MongoDB 4.0.4
はじめに、下記 3つのドキュメントが sample
コレクションへ登録されているとします。
ドキュメント内容
{ "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] } { "_id" : 2, "items" : [ { "color" : "red", "size" : "L" }, { "color" : "blue", "size" : "S" } ] } { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
ここで、items.color
が white
のものだけを抽出し、以下の結果を得る事を目指してみます。(items の中身が white のものだけを含むようにする)
目標とする検索結果
{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] } { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
(a) items.color で条件指定
まずは {"items.color": "white"}
の条件で find
した結果です。
white を持つドキュメントだけを抽出できましたが、ドキュメントの内容はそのままなので black 等の余計なものも含んでしまいます。
> db.sample.find({"items.color": "white"}) { "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] } { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
(b) $elemMatch 使用
次に $elemMatch
を使ってみます。
$elemMatch を find の query(第一引数)で使うか、projection(第二引数)で使うかで結果が変わります。
query で使う場合は先程の (a) と同じ結果になります。
query で使用
> db.sample.find({"items": {$elemMatch: {"color": "white"}}}) { "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] } { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
query の条件は指定せずに projection で $elemMatch を使った場合の結果は以下です。
全ドキュメントを対象に items の中身がフィルタリングされていますが、条件に合致する全ての子要素が抽出されるわけでは無く、(条件に合致する)先頭の要素しか含まれていません。
projection で使用1
> db.sample.find({}, {"items": {$elemMatch: {"color": "white"}}}) { "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] } { "_id" : 2 } { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }
query 条件を指定する事で不要なドキュメントを除く事はできますが、条件に合致する全ての子要素が抽出されない事に変わりはありません。
projection で使用2
> db.sample.find({"items.color": "white"}, {"items": {$elemMatch: {"color": "white"}}}) { "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] } { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }
このように、今回のようなドキュメントに対して $elemMatch を使うと、条件に合致する最初の子要素だけが抽出されるようです。(何らかの回避策があるのかもしれませんが)
(c) aggregate 使用
最後に aggregate
を使ってみます。
$unwind
を使うと配列の個々の要素を処理できるので、$match
で white のみに限定した後、$group
でグルーピングすれば良さそうです。
> db.sample.aggregate([ {$unwind: "$items"}, {$match: {"items.color": "white"}}, {$group: {_id: "$_id", "items": {$push: "$items"}}} ]) { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] } { "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
これで目指した結果は一応得られました。
なお、対象を最初に絞り込むようにしてソートを付けると以下のようになります。
> db.sample.aggregate([ {$match: {"items.color": "white"}}, {$unwind: "$items"}, {$match: {"items.color": "white"}}, {$group: {_id: "$_id", "items": {$push: "$items"}}}, {$sort: {_id: 1}} ]) { "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] } { "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
Scala のケースクラスに制約を持たせる
Scala のケースクラスで値に制約を持たせたい場合にどうするか。
例えば、以下のケースクラスで amount
の値を 0 以上となるように制限し、0 未満ならインスタンス化を失敗させる事を考えてみます。
case class Quantity(amount: Int)
使用した環境は以下
- Scala 2.12.7
ソースは http://github.com/fits/try_samples/tree/master/blog/20181009/
ケースクラスの値を制限
まず、最も単純なのは以下のような実装だと思います。
case class Quantity(amount: Int) { if (amount < 0) throw new IllegalArgumentException(s"amount($amount) < 0") }
これだと、例外が throw されてしまい関数プログラミングで扱い難いので Try[Quantity]
や Option[Quantity]
等を返すようにしたいところです。
そこで、以下のようにケースクラスを abstract
化して、コンパニオンオブジェクトへ生成関数を定義する方法を使ってみました。
sample.scala
import scala.util.{Try, Success, Failure} sealed abstract case class Quantity private (amount: Int) object Quantity { def apply(amount: Int): Try[Quantity] = if (amount >= 0) Success(new Quantity(amount){}) else Failure(new IllegalArgumentException(s"amount($amount) < 0")) } println(Quantity(1)) println(Quantity(0)) println(Quantity(-1)) // この方法では Quantity へ copy がデフォルト定義されないため // copy は使えません(error: value copy is not a member of this.Quantity) // // println(Quantity(1).map(_.copy(-1)))
実行結果は以下の通りです。
実行結果
> scala sample.scala Success(Quantity(1)) Success(Quantity(0)) Failure(java.lang.IllegalArgumentException: amount(-1) < 0)
上記 sample.scala
では、以下を直接呼び出せないようにしてケースクラスの勝手なインスタンス化を防止しています。
- (a) コンストラクタ(new)
- (b) コンパニオンオブジェクトへデフォルト定義される apply
- (c) ケースクラスへデフォルト定義される copy
そのために、下記 2点を実施しています。
- (1) コンストラクタの private 化 : (a) の防止
- (2) ケースクラスの abstract 化 : (b) (c) の防止
(1) コンストラクタの private 化
以下のように private
を付ける事でコンストラクタを private 化できます。
コンストラクタの private 化
case class Quantity private (amount: Int)
これで (a) new Quantity(・・・)
の実行を防止できますが、以下のように (b) の apply や (c) の copy を実行できてしまいます。
検証例
scala> case class Quantity private (amount: Int) defined class Quantity scala> new Quantity(1) <console>:14: error: constructor Quantity in class Quantity cannot be accessed in object $iw new Quantity(1) scala> Quantity(1) res1: Quantity = Quantity(1) scala> Quantity.apply(2) res2: Quantity = Quantity(2) scala> Quantity(3).copy(30) res3: Quantity = Quantity(30)
(2) ケースクラスの abstract 化
ケースクラスを abstract
化すると、通常ならデフォルト定義されるコンパニオンオブジェクトの apply やケースクラスの copy を防止できるようです。
そのため、(1) と組み合わせることで (a) ~ (c) を防止できます。
ケースクラスの abstract 化とコンストラクタの private 化
sealed abstract case class Quantity private (amount: Int)
以下のように Quantity.apply
は定義されなくなります。
検証例
scala> sealed abstract case class Quantity private (amount: Int) defined class Quantity scala> new Quantity(1){} <console>:14: error: constructor Quantity in class Quantity cannot be accessed in <$anon: Quantity> new Quantity(1){} ^ scala> Quantity.apply(1) <console>:14: error: value apply is not a member of object Quantity Quantity.apply(1)
このままだと何もできなくなるため、実際はコンパニオンオブジェクトへ生成用の関数が必要になります。
例
sealed abstract case class Quantity private (amount: Int) object Quantity { def create(amount: Int): Quantity = new Quantity(amount){} }
備考
今回の方法は、以下の書籍に記載されているような ADTs(algebraic data types)と Smart constructors をより安全に定義するために活用できると考えています。
Functional and Reactive Domain Modeling
- 作者: Debasish Ghosh
- 出版社/メーカー: Manning Publications
- 発売日: 2016/10/24
- メディア: ペーパーバック
- この商品を含むブログを見る
Kotlin の関数型プログラミング用ライブラリ Λrrow を試してみる
Kotlin で Scala の Scalaz や Cats のような関数型プログラミング用のライブラリを探していたところ、以下を見つけたので試してみました。
ソースは http://github.com/fits/try_samples/tree/master/blog/20180822/
はじめに
Λrrow は以下のような要素で構成されており、Haskell の型クラス・インスタンスのような仕組みを実現しているようです。
- Datatypes (Option, Either, Try, Kleisli, StateT, WriterT 等)
- Typeclasses (Functor, Applicative, Monad, Monoid, Eq, Show 等)
- Instances (OptionMonadInstance 等)
そして、その実現にはアノテーションプロセッサが活用されているようです。(@higherkind
や @instance
から補助的なコードを生成)
Option のソースコード
例えば、Option
のソースコードは以下のようになっていますが、OptionOf
の定義は github のソース上には見当たりません。
arrow/core/Option.kt
@higherkind sealed class Option<out A> : OptionOf<A> { ・・・ }
OptionOf はアノテーションプロセッサ ※ で生成したコード(以下)で定義されています。(OptionOf<A>
は Kind<ForOption, A>
の型エイリアス)
higherkind.arrow.core.Option.kt
package arrow.core class ForOption private constructor() { companion object } typealias OptionOf<A> = arrow.Kind<ForOption, A> @Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE") inline fun <A> OptionOf<A>.fix(): Option<A> = this as Option<A>
※ @higherkind に対する処理は arrow.higherkinds.HigherKindsFileGenerator に実装されているようです
ここで、ForOption
というクラスが定義されていますが、この For + 型名
はコンテナ型を表現するための型のようです。
ForOption がある事で Kind<Option<A>, A>
ではなく Kind<ForOption, A>
と定義できるようになっており、Functor<F>
や Monad<F>
における F の具体型として使われています。(例. Monad<ForOption>
)
Either のソースコード
次に Either
のソースコードも見てみます。
Either のように型パラメータが複数の場合はアノテーションプロセッサで生成されるコードの内容に多少の違いがあるようです。
arrow/core/Either.kt
@higherkind sealed class Either<out A, out B> : EitherOf<A, B> { ・・・ }
アノテーションプロセッサで生成されたコードは以下のように EitherOf
の他に EitherPartialOf
が型エイリアスとして定義されています。
higherkind.arrow.core.Either.kt
package arrow.core class ForEither private constructor() { companion object } typealias EitherOf<A, B> = arrow.Kind2<ForEither, A, B> typealias EitherPartialOf<A> = arrow.Kind<ForEither, A> @Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE") inline fun <A, B> EitherOf<A, B>.fix(): Either<A, B> = this as Either<A, B>
また、上記で定義されている ForEither クラスとは別に arrow-instances-core
モジュールには ForEither 関数
も用意されています。
arrow/instances/either.kt
・・・ class EitherContext<L> : EitherMonadErrorInstance<L>, EitherTraverseInstance<L>, EitherSemigroupKInstance<L> { override fun <A, B> Kind<EitherPartialOf<L>, A>.map(f: (A) -> B): Either<L, B> = fix().map(f) } class EitherContextPartiallyApplied<L> { infix fun <A> extensions(f: EitherContext<L>.() -> A): A = f(EitherContext()) } fun <L> ForEither(): EitherContextPartiallyApplied<L> = EitherContextPartiallyApplied()
この ForEither 関数は extensions
を呼び出す際に使用する事になります。
サンプルコード
Option と Either をそれぞれ使ったサンプルを作成してみます。
(a) Option の利用
Option の Some
や None
の作成にはいくつかの方法が用意されています。
Monadbinding
を使うと Haskell の do 記法や Scala の for 式のようにモナドを処理できるようです。(Kotlin の coroutines 機能で実現)
binding の呼び出し方はいくつか考えられますが、ForOption extensions { ・・・ }
を使うのが基本のようです。
ForOption extensions へ渡す処理内では this が OptionContext
となるため、OptionContext の処理を呼び出せるようになっています。
また、binding へ渡す処理内では MonadContinuation<ForOption, *>
が this となります。
bind
は MonadContinuation 内で suspend fun <B> Kind<F, B>.bind(): B
と定義されており、Kind<F, B> から B の値を取り出す処理となっています。
src/main/kotlin/App.kt
import arrow.core.* import arrow.instances.* import arrow.typeclasses.binding fun main(args: Array<String>) { // Some val d1: Option<Int> = Option.just(10) val d2: Option<Int> = 5.some() val d3: Option<Int> = Some(2) // None val d4: Option<Int> = Option.empty() val d5: Option<Int> = none() val d6: Option<Int> = None // Some(15) val r1 = d1.flatMap { a -> d2.map { b -> a + b } } println(r1) // Some(17) val r2 = Option.monad().binding { // this: MonadContinuation<ForOption, *> val a = d1.bind() // 10 val b = d2.bind() // 5 val c = d3.bind() // 2 a + b + c } println(r2) // Some(17) val r3 = ForOption extensions { // this: OptionContext println(this) // arrow.instances.OptionContext@3ffc5af1 binding { // this: MonadContinuation<ForOption, *> println(this) // arrow.typeclasses.MonadContinuation@26653222 val a = d1.bind() val b = d2.bind() val c = d3.bind() a + b + c } } println(r3) // Some(17) val r4 = OptionContext.binding { val a = d1.bind() val b = d2.bind() val c = d3.bind() a + b + c } println(r4) // None val r5 = Option.monad().binding { val a = d1.bind() val b = d4.bind() a + b } println(r5) // None val r6 = ForOption extensions { binding { val a = d5.bind() val b = d2.bind() val c = d6.bind() a + b + c } } println(r6) }
Option のような基本的な型を使うだけであれば、依存ライブラリとして arrow-instances-core
を指定するだけで問題なさそうです。
そのため、Gradle ビルド定義ファイルは下記のような内容になります。
ここでは The feature "coroutines" is experimental (see: https://kotlinlang.org/docs/diagnostics/experimental-coroutines)
警告ログを出力しないように coroutines を有効化しています。
build.gradle (Gradle ビルド定義)
plugins { id 'org.jetbrains.kotlin.jvm' version '1.2.51' id 'application' } mainClassName = 'AppKt' // 実行するクラス名 repositories { jcenter() } dependencies { compile 'org.jetbrains.kotlin:kotlin-stdlib-jdk8' compile 'io.arrow-kt:arrow-instances-core:0.7.3' } // coroutines の有効化(ビルド時の警告を抑制) kotlin { experimental { coroutines 'enable' } }
実行結果は以下の通りです。
実行結果
> gradle run ・・・ Some(15) Some(17) arrow.instances.OptionContext@3ffc5af1 arrow.typeclasses.MonadContinuation@26653222 Some(17) Some(17) None None
(b) Either の利用
Either の場合、extensions の呼び出し方が Option とは多少異なります。
Option の場合は ForOption の extensions を呼び出しましたが、Either の場合は ForEither に extensions は用意されておらず、代わりに EitherContextPartiallyApplied<L>
クラス内に定義されています。
EitherContextPartiallyAppliedForEither 関数
で取得できるので、これを使って ForEither<Left 側の型>() extensions { ・・・ }
のようにします。
ForEither<String>() extensions
へ渡す処理内の this は EitherContext<String>
で、binding へ渡す処理内の this は MonadContinuation<EitherPartialOf<String>, *>
となります。
src/main/kotlin/App.kt
import arrow.core.* import arrow.instances.* import arrow.typeclasses.binding fun main(args: Array<String>) { // Right val d1: Either<String, Int> = Either.right(10) val d2: Either<String, Int> = 5.right() val d3: Either<String, Int> = Right(2) // Left val d4: Either<String, Int> = Either.left("error data") // Right(b=15) val r1 = d1.flatMap { a -> d2.map { b -> a + b } } println(r1) // Right(b=17) val r2 = Either.monad<String>().binding { // this: MonadContinuation<EitherPartialOf<String>, *> val a = d1.bind() val b = d2.bind() val c = d3.bind() a + b + c } println(r2) // Right(b=17) // ForEither<String>() 関数の呼び出し val r3 = ForEither<String>() extensions { // this: EitherContext<String> binding { // this: MonadContinuation<EitherPartialOf<String>, *> val a = d1.bind() val b = d2.bind() val c = d3.bind() a + b + c } } println(r3) // Left(a=error data) val r4 = ForEither<String>() extensions { binding { val a = d1.bind() val b = d4.bind() val c = d3.bind() a + b + c } } println(r4) }
(a) と同じ内容の build.gradle を使って実行します。
実行結果
> gradle run ・・・ Right(b=15) Right(b=17) Right(b=17) Left(a=error data)
TypeScript で funfix を使用 - tsc, FuseBox
funfix は JavaScript, TypeScript, Flow の関数型プログラミング用ライブラリで、Fantasy Land や Static Land ※ に準拠し Scala の Option, Either, Try, Future 等と同等の型が用意されているようです。
※ JavaScript 用に Monoid や Monad 等の代数的構造に関する仕様を定義したもの
今回は Option を使った単純な処理を TypeScript で実装し Node.js で実行してみます。
ソースは http://github.com/fits/try_samples/tree/master/blog/20180730/
はじめに
Option
を使った下記サンプルをコンパイルして実行します。
サンプルソース
import { Option, Some } from 'funfix' const f = (ma, mb) => ma.flatMap(a => mb.map(b => `${a} + ${b} = ${a + b}`)) const d1 = Some(10) const d2 = Some(2) console.log( d1 ) console.log('-----') console.log( f(d1, d2) ) console.log( f(d1, Option.none()) ) console.log('-----') console.log( f(d1, d2).getOrElse('none') ) console.log( f(d1, Option.none()).getOrElse('none') )
ビルドと実行
上記ソースファイルを以下の 2通りでビルドして実行してみます。
- (a) tsc 利用
- (b) FuseBox 利用
(a) tsc を利用する場合
tsc コマンドを使って TypeScript のソースをコンパイルします。
まずは typescript と funfix モジュールをそれぞれインストールします。
typescript インストール
> npm install --save-dev typescript
funfix インストール
> npm install --save funfix
この状態で sample.ts
ファイルをコンパイルしてみると、型関係のエラーが出るものの sample.js
は正常に作られました。
コンパイル1
> tsc sample.ts node_modules/funfix-core/dist/disjunctions.d.ts:775:14 - error TS2416: Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'. Type 'undefined' is not assignable to type 'never'. 775 readonly value: undefined; ~~~~~ node_modules/funfix-effect/dist/eval.d.ts:256:42 - error TS2304: Cannot find name 'Iterable'. 256 static sequence<A>(list: Eval<A>[] | Iterable<Eval<A>>): Eval<A[]>; ~~~~~~~~ ・・・
sample.js を実行してみると特に問題無く動作します。
実行1
> node sample.js TSome { _isEmpty: false, value: 10 } ----- TSome { _isEmpty: false, value: '10 + 2 = 12' } TNone { _isEmpty: true, value: undefined } ----- 10 + 2 = 12 none
これで一応は動いたわけですが、コンパイル時にエラーが出るというのも望ましい状態ではないので、エラー部分を解決してみます。
他にも方法があるかもしれませんが、ここでは以下のように対応します。
- (1)
Property 'value' in type 'TNone' ・・・ 'Option<never>'
のエラーに対して tsc 実行時に--strictNullChecks
オプションを指定して対応 - (2)
Cannot find name 'Iterable'
等のエラーに対して@types/node
をインストールして対応
strictNullChecks は tsc の実行時オプションで指定する以外にも設定ファイル tsconfig.json
で設定する事もできるので、ここでは tsconfig.json ファイルを使います。
(1) tsconfig.json
{ "compilerOptions": { "strictNullChecks": true } }
次に @types/node をインストールします。
@types/node には Node.js で実行するための型定義(Node.js 依存の API 等)が TypeScript 用に定義されています。
(2) @types/node インストール
> npm install --save-dev @types/node
この状態で tsc を実行すると先程のようなエラーは出なくなりました。(tsconfig.json を適用するため tsc コマンドへ引数を指定していない点に注意)
コンパイル2
> tsc
実行結果にも差異はありません。
実行2
> node sample.js TSome { _isEmpty: false, value: 10 } ----- TSome { _isEmpty: false, value: '10 + 2 = 12' } TNone { _isEmpty: true, value: undefined } ----- 10 + 2 = 12 none
最終的な package.json の内容は以下の通りです。
package.json
{ "name": "sample", "version": "1.0.0", "devDependencies": { "@types/node": "^10.5.4", "typescript": "^2.9.2" }, "dependencies": { "funfix": "^7.0.1" } }
(b) FuseBox を利用する場合
次に、モジュールバンドラーの FuseBox を使用してみます。(以降は (a) とは異なるディレクトリで実施)
なお、ここでは npm の代わりに yarn
を使っていますが、npm でも特に問題はありません。
yarn のインストール例(npm 使用)
> npm install -g yarn
(b-1) 型チェック無し
まずは typescript, fuse-box, funfix をそれぞれインストールしておきます。
typescript と fuse-box インストール
> yarn add --dev typescript fuse-box
funfix インストール
> yarn add funfix
FuseBox ではビルド定義を JavaScript のコードで記載します。 とりあえずは必要最小限の設定を行いました。
bundle で指定した名称が init の $name
に適用されるため、*.ts
のコンパイル結果と依存モジュールの内容をバンドルして bundle.js
へ出力する事になります。
なお、>
でロード時に実行する(コードを記載した)ファイルを指定します。
fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box') const fuse = FuseBox.init({ output: '$name.js' }) fuse.bundle('bundle').instructions('> *.ts') fuse.run()
上記スクリプトを実行してビルド(TypeScript のコンパイルとバンドル)を行います。
ビルド
> node fuse.js --- FuseBox 3.4.0 --- → Generating recommended tsconfig.json: ・・・\sample_fusebox1\tsconfig.json → Typescript script target: ES7 -------------------------- Bundle "bundle" sample.js └── (1 files, 700 Bytes) default └── funfix-core 34.4 kB (1 files) └── funfix-effect 43.1 kB (1 files) └── funfix-exec 79.5 kB (1 files) └── funfix 1 kB (1 files) size: 158.7 kB in 765ms
初回実行時にデフォルト設定の tsconfig.json が作られました。(tsconfig.json が存在しない場合)
tsc の時のような型関係のエラーは出ていませんが、これは FuseBox がデフォルトで TypeScript の型チェックをしていない事が原因のようです。
型チェックを実施するには fuse-box-typechecker
プラグインを使う必要がありそうです。
実行
> node bundle.js TSome { _isEmpty: false, value: 10 } ----- TSome { _isEmpty: false, value: '10 + 2 = 12' } TNone { _isEmpty: true, value: undefined } ----- 10 + 2 = 12 none
package.json の内容は以下の通りです。
package.json
{ "name": "sample_fusebox1", "version": "1.0.0", "main": "bundle.js", "license": "MIT", "devDependencies": { "fuse-box": "^3.4.0", "typescript": "^2.9.2" }, "dependencies": { "funfix": "^7.0.1" } }
(b-2) 型チェック有り
TypeScript の型チェックを行うようにしてみます。
まずは、(b-1) と同じ構成に fuse-box-typechecker
プラグインを加えます。
fuse-box-typechecker を追加インストール
> yarn add --dev fuse-box-typechecker
次に、fuse.js へ fuse-box-typechecker プラグインの設定を追加します。
TypeChecker で型チェックにエラーがあった場合、例外が throw されるようにはなっていないため、ここではエラーがあった場合に Error を throw して fuse.run() を実行しないようにしてみました。
ただし、こうすると tsconfig.json を予め用意しておく必要があります。(TypeChecker に tsconfig.json が必要)
fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box') const {TypeChecker} = require('fuse-box-typechecker') // fuse-box-typechecker の設定 const testSync = TypeChecker({ tsConfig: './tsconfig.json' }) const fuse = FuseBox.init({ output: '$name.js' }) fuse.bundle('bundle').instructions('> *.ts') testSync.runPromise() .then(n => { if (n != 0) { // 型チェックでエラーがあった場合 throw new Error(n) } // 型チェックでエラーがなかった場合 return fuse.run() }) .catch(console.error)
これで、ビルド時に (a) と同様の型エラーが出るようになりました。
ビルド1
> node fuse.js ・・・ --- FuseBox 3.4.0 --- Typechecker plugin(promisesync) . Time:Sun Jul 29 2018 12:40:47 GMT+0900 (GMT+09:00) File errors: └── .\node_modules\funfix-core\dist\disjunctions.d.ts | ・・・\sample_fusebox2\node_modules\funfix-core\dist\disjunctions.d.ts (775,14) (Error:TS2416) Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'. Type 'undefined' is not assignable to type 'never'. Errors:1 └── Options: 0 └── Global: 0 └── Syntactic: 0 └── Semantic: 1 └── TsLint: 0 Typechecking time: 4116ms Quitting typechecker ・・・
ここで、Iterable
の型エラーが出ていないのは fuse-box-typechecker のインストール時に @types/node もインストールされているためです。
(a) と同様に strictNullChecks の設定を tsconfig.json へ追記して、このエラーを解決します。
tsconfig.json へ strictNullChecks の設定を追加
{ "compilerOptions": { "module": "commonjs", "target": "ES7", ・・・ "strictNullChecks": true } }
これでビルドが成功するようになりました。
ビルド2
> node fuse.js ・・・ Typechecker name: undefined Typechecker basepath: ・・・\sample_fusebox2 Typechecker tsconfig: ・・・\sample_fusebox2\tsconfig.json --- FuseBox 3.4.0 --- Typechecker plugin(promisesync) . Time:Sun Jul 29 2018 12:44:57 GMT+0900 (GMT+09:00) All good, no errors :-) Typechecking time: 4103ms Quitting typechecker killing worker → Typescript config file: \tsconfig.json → Typescript script target: ES7 -------------------------- Bundle "bundle" sample.js └── (1 files, 700 Bytes) default └── funfix-core 34.4 kB (1 files) └── funfix-effect 43.1 kB (1 files) └── funfix-exec 79.5 kB (1 files) └── funfix 1 kB (1 files) size: 158.7 kB in 664ms
実行結果
> node bundle.js TSome { _isEmpty: false, value: 10 } ----- TSome { _isEmpty: false, value: '10 + 2 = 12' } TNone { _isEmpty: true, value: undefined } ----- 10 + 2 = 12 none
package.json の内容は以下の通りです。
package.json
{ "name": "sample_fusebox2", "version": "1.0.0", "main": "index.js", "license": "MIT", "devDependencies": { "fuse-box": "^3.4.0", "fuse-box-typechecker": "^2.10.0", "typescript": "^2.9.2" }, "dependencies": { "funfix": "^7.0.1" } }
Word2Vec を用いた併売の分析 - gensim
「トピックモデルを用いた併売の分析」ではトピックモデルによる併売の分析を試しましたが、今回は gensim の Word2Vec で試してみました。
ソースは http://github.com/fits/try_samples/tree/master/blog/20180617/
はじめに
データセット
これまで は適当に作ったデータセットを使っていましたが、今回は R の Groceries データセット ※ をスペース区切りのテキストファイル(groceries.txt)にして使います。(商品名にスペースを含む場合は代わりに _
を使っています)
※ ある食料雑貨店における 30日間の POS データ
groceries.txt
citrus_fruit semi-finished_bread margarine ready_soups tropical_fruit yogurt coffee whole_milk pip_fruit yogurt cream_cheese_ meat_spreads other_vegetables whole_milk condensed_milk long_life_bakery_product whole_milk butter yogurt rice abrasive_cleaner rolls/buns other_vegetables UHT-milk rolls/buns bottled_beer liquor_(appetizer) pot_plants whole_milk cereals ・・・ cooking_chocolate chicken citrus_fruit other_vegetables butter yogurt frozen_dessert domestic_eggs rolls/buns rum cling_film/bags semi-finished_bread bottled_water soda bottled_beer chicken tropical_fruit other_vegetables vinegar shopping_bags
R によるアソシエーションルールの抽出結果
参考のため、まずは R を使って Groceries データセットを apriori
で処理しました。
リフト値を優先するため、支持度 (supp
) と確信度 (conf
) を低めの値にしています。
groceries_apriori.R
library(arules) data(Groceries) params <- list(supp = 0.001, conf = 0.1) rules <- apriori(Groceries, parameter = params) inspect(head(sort(rules, by = "lift"), 10))
実行結果
> Rscript groceries_apriori.R ・・・ lhs rhs support confidence lift count [1] {bottled beer, red/blush wine} => {liquor} 0.001931876 0.3958333 35.71579 19 [2] {hamburger meat, soda} => {Instant food products} 0.001220132 0.2105263 26.20919 12 [3] {ham, white bread} => {processed cheese} 0.001931876 0.3800000 22.92822 19 [4] {root vegetables, other vegetables, whole milk, yogurt} => {rice} 0.001321810 0.1688312 22.13939 13 [5] {bottled beer, liquor} => {red/blush wine} 0.001931876 0.4130435 21.49356 19 [6] {Instant food products, soda} => {hamburger meat} 0.001220132 0.6315789 18.99565 12 [7] {curd, sugar} => {flour} 0.001118454 0.3235294 18.60767 11 [8] {soda, salty snack} => {popcorn} 0.001220132 0.1304348 18.06797 12 [9] {sugar, baking powder} => {flour} 0.001016777 0.3125000 17.97332 10 [10] {processed cheese, white bread} => {ham} 0.001931876 0.4634146 17.80345 19
bottled beer と red/blush wine で liquor が同時に買われやすい、hamburger meat と soda で Instant food products が同時に買われやすいという結果(アソシエーションルール)が出ています。
Word2Vec の適用
それでは groceries.txt を gensim の Word2Vec で処理してみます。
とりあえず iter
を 500 に min_count
を 1 にしてみました。
なお、購入品目の多い POS データを処理する場合は window
パラメータを大きめにすべきかもしれません。(今回はデフォルト値の 5)
今回は Jupyter Notebook で実行しています。
Word2Vec モデルの構築
from gensim.models import word2vec sentences = word2vec.LineSentence('groceries.txt') model = word2vec.Word2Vec(sentences, iter = 500, min_count = 1)
類似品の算出
まずは、wv.most_similar
で類似単語(商品)を抽出してみます。
pork の類似単語
model.wv.most_similar('pork')
[('turkey', 0.5547687411308289), ('ham', 0.49448296427726746), ('pip_fruit', 0.46879759430885315), ('tropical_fruit', 0.4383287727832794), ('butter', 0.43373265862464905), ('frankfurter', 0.4334157109260559), ('root_vegetables', 0.4249211549758911), ('citrus_fruit', 0.4246293306350708), ('chicken', 0.42378148436546326), ('sausage', 0.41153857111930847)]
微妙なものも含んでいますが、それなりの結果になっているような気もします。
most_similar
はベクトル的に類似している単語を抽出するため、POS データを処理する場合は競合や代用品の抽出に使えるのではないかと思います。
併売の分析
併売の商品はお互いに類似していないと思うので most_similar は役立ちそうにありませんが、それでも何らかの関係性はありそうな気がします。
そこで、指定した単語群の中心となる単語を抽出する predict_output_word
を使えないかと思い、R で抽出したアソシエーションルールの組み合わせで試してみました。
predict_output_word の検証
bottled_beer と red/blush_wine
model.predict_output_word(['bottled_beer', 'red/blush_wine'])
[('liquor', 0.22384332), ('prosecco', 0.04933687), ('sparkling_wine', 0.0345262), ・・・]
R の結果に出てた liquor が先頭(確率が最大)に来ています。
bottled_beer と red/blush_wine
model.predict_output_word(['hamburger_meat', 'soda'])
[('Instant_food_products', 0.054281656), ('canned_vegetables', 0.029985178), ('pasta', 0.025487985), ・・・]
ここでも R の結果に出てた Instant_food_products が先頭に来ています。
ham と white_bread
model.predict_output_word(['ham', 'white_bread'])
[('processed_cheese', 0.20990367), ('sweet_spreads', 0.024131883), ('spread_cheese', 0.023222428), ・・・]
こちらも同様です。
root_vegetables と other_vegetables と whole_milk と yogurt
model.predict_output_word(['root_vegetables', 'other_vegetables', 'whole_milk', 'yogurt'])
[('herbs', 0.024541182), ('liver_loaf', 0.019327056), ('turkey', 0.01775743), ('onions', 0.01760579), ('specialty_cheese', 0.014991459), ('packaged_fruit/vegetables', 0.014529809), ('spread_cheese', 0.012931713), ('meat', 0.012434797), ('beef', 0.011924307), ('butter_milk', 0.011828974)]
R の結果にあった rice はこの中には含まれていません。
curd と sugar
model.predict_output_word(['curd', 'sugar'])
[('flour', 0.076272935), ('pudding_powder', 0.055790607), ('baking_powder', 0.026003197), ・・・]
R の結果に出てた flour (小麦粉) が先頭に来ています。
soda と salty_snack
model.predict_output_word(['soda', 'salty_snack'])
[('popcorn', 0.05830234), ('nut_snack', 0.046429735), ('chewing_gum', 0.0213278), ・・・]
こちらも同様です。
sugar と baking_powder
model.predict_output_word(['sugar', 'baking_powder'])
[('flour', 0.11954326), ('cooking_chocolate', 0.046284538), ('pudding_powder', 0.03714784), ・・・]
こちらも同様です。
以上のように、少なくとも 2品を指定した場合の predict_output_word
の結果は R で抽出したアソシエーションルールに合致しているようです。
Word2Vec のパラメータに左右されるのかもしれませんが、この結果を見る限りでは predict_output_word を 3品の併売の組み合わせ抽出に使えるかもしれない事が分かりました。
3品の併売
次に predict_output_word で 2品に対する 1品を確率の高い順に抽出してみました。
なお、ここでは 3品の組み合わせの購入数が 10 未満のものは除外するようにしています。
from collections import Counter import itertools # 3品の組み合わせのカウント tri_counter = Counter([c for ws in sentences for c in itertools.combinations(sorted(ws), 3)]) # 2品の組み合わせを作成 pairs = itertools.combinations(model.wv.vocab.keys(), 2) sorted([ (p, item, prob) for p in pairs for item, prob in model.predict_output_word(p) if prob >= 0.05 and tri_counter[tuple(sorted([p[0], p[1], item]))] >= 10 ], key = lambda x: -x[2])
[(('bottled_beer', 'red/blush_wine'), 'liquor', 0.22384332), (('white_bread', 'ham'), 'processed_cheese', 0.20990367), (('bottled_beer', 'liquor'), 'red/blush_wine', 0.16274776), (('sugar', 'baking_powder'), 'flour', 0.11954326), (('curd', 'sugar'), 'flour', 0.076272935), (('margarine', 'sugar'), 'flour', 0.07422828), (('flour', 'sugar'), 'baking_powder', 0.07345509), (('sugar', 'whipped/sour_cream'), 'flour', 0.072731614), (('rolls/buns', 'hamburger_meat'), 'Instant_food_products', 0.06818052), (('sugar', 'root_vegetables'), 'flour', 0.0641469), (('tropical_fruit', 'white_bread'), 'processed_cheese', 0.061861355), (('soda', 'ham'), 'processed_cheese', 0.06138085), (('white_bread', 'processed_cheese'), 'ham', 0.061199907), (('whole_milk', 'ham'), 'processed_cheese', 0.059773713), (('beef', 'root_vegetables'), 'herbs', 0.059243686), (('sugar', 'whipped/sour_cream'), 'baking_powder', 0.05871357), (('soda', 'salty_snack'), 'popcorn', 0.05830234), (('soda', 'popcorn'), 'salty_snack', 0.05819882), (('red/blush_wine', 'liquor'), 'bottled_beer', 0.057226427), (('flour', 'baking_powder'), 'sugar', 0.05517209), (('soda', 'hamburger_meat'), 'Instant_food_products', 0.054281656), (('processed_cheese', 'ham'), 'white_bread', 0.053193364), (('other_vegetables', 'ham'), 'processed_cheese', 0.052585844)]
R で抽出したアソシエーションルールと同じ様な結果が出ており、それなりの結果が出ているように思います。
skip-gram の場合
gensim の Word2Vec はデフォルトで CBoW を使うようですので、skip-gram の場合にどうなるかも簡単に確認してみました。
skip-gram の使用
model = word2vec.Word2Vec(sentences, iter = 500, min_count = 1, sg = 1)
まずは predict_output_word の結果をいくつか見てみます。
先頭(確率が最大のもの)は変わらないようですが、CBoW よりも確率の値が全体的に低くなっているようです。
bottled_beer と red/blush_wine
model.predict_output_word(['bottled_beer', 'red/blush_wine'])
[('liquor', 0.076620705), ('prosecco', 0.030791236), ('liquor_(appetizer)', 0.027123762), ・・・]
hamburger_meat と soda
model.predict_output_word(['hamburger_meat', 'soda'])
[('Instant_food_products', 0.022627866), ('pasta', 0.018009944), ('canned_vegetables', 0.01685342), ・・・]
root_vegetables と other_vegetables と whole_milk と yogurt
model.predict_output_word(['root_vegetables', 'other_vegetables', 'whole_milk', 'yogurt'])
[('herbs', 0.015105391), ('turkey', 0.014365919), ('rice', 0.01316431), ・・・]
ここでは、CBoW で 10番以内に入っていなかった rice が入っています。
次に、先程と同様に predict_output_word で 3品の組み合わせを確率順に抽出してみます。
確率の値が全体的に下がっているため、最小値の条件を 0.02 へ変えています。
predict_output_word を使った 3品の組み合わせ抽出
・・・ sorted([ (p, item, prob) for p in pairs for item, prob in model.predict_output_word(p) if prob >= 0.02 and tri_counter[tuple(sorted([p[0], p[1], item]))] >= 10 ], key = lambda x: -x[2])
[(('bottled_beer', 'red/blush_wine'), 'liquor', 0.076620705), (('bottled_beer', 'liquor'), 'red/blush_wine', 0.0712179), (('white_bread', 'ham'), 'processed_cheese', 0.039820198), (('red/blush_wine', 'liquor'), 'bottled_beer', 0.031292748), (('sugar', 'baking_powder'), 'flour', 0.030803043), (('sugar', 'whipped/sour_cream'), 'flour', 0.029322423), (('margarine', 'sugar'), 'flour', 0.027827), (('beef', 'root_vegetables'), 'herbs', 0.02740662), (('curd', 'sugar'), 'flour', 0.025570681), (('flour', 'sugar'), 'baking_powder', 0.025403246), (('tropical_fruit', 'root_vegetables'), 'turkey', 0.025329975), (('whole_milk', 'ham'), 'processed_cheese', 0.024535457), (('rolls/buns', 'hamburger_meat'), 'Instant_food_products', 0.02427808), (('flour', 'baking_powder'), 'sugar', 0.023779714), (('tropical_fruit', 'white_bread'), 'processed_cheese', 0.023528077), (('sugar', 'root_vegetables'), 'flour', 0.023394365), (('soda', 'salty_snack'), 'popcorn', 0.02322538), (('whole_milk', 'sugar'), 'flour', 0.023202542), (('fruit/vegetable_juice', 'ham'), 'processed_cheese', 0.023127634), (('butter', 'root_vegetables'), 'herbs', 0.02304014), (('soda', 'ham'), 'processed_cheese', 0.022633638), (('soda', 'hamburger_meat'), 'Instant_food_products', 0.022627866), (('citrus_fruit', 'sugar'), 'flour', 0.022040429), (('bottled_beer', 'soda'), 'liquor', 0.02189085), (('processed_cheese', 'ham'), 'white_bread', 0.021692872), (('yogurt', 'sugar'), 'flour', 0.021522585), (('tropical_fruit', 'other_vegetables'), 'turkey', 0.021456005), (('other_vegetables', 'beef'), 'herbs', 0.021407435), (('white_bread', 'processed_cheese'), 'ham', 0.021362728), (('curd', 'root_vegetables'), 'herbs', 0.021005861), (('other_vegetables', 'ham'), 'processed_cheese', 0.020965746), (('root_vegetables', 'whipped/sour_cream'), 'herbs', 0.020788824), (('other_vegetables', 'root_vegetables'), 'herbs', 0.020782541), (('sugar', 'whipped/sour_cream'), 'baking_powder', 0.02058014), (('whole_milk', 'sugar'), 'rice', 0.020371588), (('root_vegetables', 'frozen_vegetables'), 'herbs', 0.02027719), (('whole_milk', 'Instant_food_products'), 'hamburger_meat', 0.020258738), (('citrus_fruit', 'root_vegetables'), 'herbs', 0.020241175)]
最小値の条件を下げたために、より多くの組み合わせを抽出していますが、CBoW の結果と大きな違いは無さそうです。
quill で DDL を実行
quill は Scala 用の DB ライブラリで、マクロを使ってコンパイル時に SQL や CQL(Cassandra)を組み立てるのが特徴となっています。
quill には Infix
という機能が用意されており、これを使うと FOR UPDATE
のような(quillが)未サポートの SQL 構文に対応したり、select 文を直接指定したりできるようですが、CREATE TABLE
のような DDL(データ定義言語)の実行は無理そうでした。
そこで、API やソースを調べてみたところ、SQL を直接実行する probe
や executeAction
という関数を見つけたので、これを使って CREATE TABLE を実行してみたいと思います。
ソースは http://github.com/fits/try_samples/tree/master/blog/20180502/
はじめに
今回は Gradle を使ってビルド・実行し、DB には H2 を(インメモリーで)使います。
build.gradle
apply plugin: 'scala' apply plugin: 'application' mainClassName = 'sample.SampleApp' repositories { jcenter() } dependencies { compile 'org.scala-lang:scala-library:2.12.6' compile 'io.getquill:quill-jdbc_2.12:2.4.2' runtime 'com.h2database:h2:1.4.192' runtime 'org.slf4j:slf4j-simple:1.8.0-beta2' }
DB の接続設定は以下のようにしました。
ctx
の部分は任意の文字列を用いることができ、H2JdbcContext
を new する際の configPrefix 引数で指定します。
src/main/resources/application.conf
ctx.dataSourceClassName=org.h2.jdbcx.JdbcDataSource ctx.dataSource.url="jdbc:h2:mem:sample" ctx.dataSource.user=sa
1. probe・executeAction で DDL を実行
それでは、probe
と executeAction
をそれぞれ使って CREATE TABLE を実行してみます。
JdbcContext
における probe
の戻り値は Try[Boolean]
で executeAction
の戻り値は Long
となっています。
sample1/src/main/scala/sample/SampleApp.scala
package sample import io.getquill.{H2JdbcContext, SnakeCase} case class Item(itemId: String, name: String) case class Stock(stockId: String, itemId: String, qty: Int) object SampleApp extends App { lazy val ctx = new H2JdbcContext(SnakeCase, "ctx") import ctx._ // probe を使った CREATE TABLE の実行 val r1 = probe("CREATE TABLE item(item_id VARCHAR(10) PRIMARY KEY, name VARCHAR(10))") println(s"create table1: $r1") // executeAction を使った CREATE TABLE の実行 val r2 = executeAction("CREATE TABLE stock(stock_id VARCHAR(10) PRIMARY KEY, item_id VARCHAR(10), qty INT)") println(s"create table2: $r2") // item への insert println( run(query[Item].insert(lift(Item("item1", "A1")))) ) println( run(query[Item].insert(lift(Item("item2", "B2")))) ) // stock への insert println( run(query[Stock].insert(lift(Stock("stock1", "item1", 5)))) ) println( run(query[Stock].insert(lift(Stock("stock2", "item2", 3)))) ) // item の select println( run(query[Item]) ) // stock の select println( run(query[Stock]) ) // Infix を使った select val selectStocks = quote( infix"""SELECT stock_id AS "_1", name AS "_2", qty AS "_3" FROM stock s join item i on i.item_id = s.item_id""".as[Query[(String, String, Int)]] ) println( run(selectStocks) ) }
実行結果は以下の通りで CREATE TABLE に成功しています。
probe の結果は Success(false)
で executeAction の結果は 0
となりました。
実行結果
> cd sample1 > gradle run ・・・ [main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Starting... [main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Start completed. create table1: Success(false) create table2: 0 1 1 1 1 List(Item(item1,A1), Item(item2,B2)) List(Stock(stock1,item1,5), Stock(stock2,item2,3)) List((stock1,A1,5), (stock2,B2,3)) ・・・
2. モナドの利用
quill には IO
モナドが用意されていたので、これを使って処理を組み立ててみます。
IO は run
の代わりに runIO
を使う事で取得でき、IO の結果は performIO
で取得します。
probe の結果である Try[A]
は IO.fromTry
を使う事で IO にできます。
また、クエリー query[A]
では flatMap 等が使えるので for 内包表記で直接合成できましたが(selectStocks
の箇所)、query[A].insert(・・・)
は flatMap 等を使えなかったので runIO しています。(insertItemAndStock
の箇所)
sample2/src/main/scala/sample/SampleApp.scala
package sample import io.getquill.{H2JdbcContext, SnakeCase} case class Item(itemId: String, name: String) case class Stock(stockId: String, itemId: String, qty: Int) object SampleApp extends App { lazy val ctx = new H2JdbcContext(SnakeCase, "ctx") import ctx._ // CREATE TABLE val createTables = for { it <- probe("CREATE TABLE item(item_id VARCHAR(10) PRIMARY KEY, name VARCHAR(10))") st <- probe("CREATE TABLE stock(stock_id VARCHAR(10) PRIMARY KEY, item_id VARCHAR(10), qty INT)") } yield (it, st) // item と stock へ insert val insertItemAndStock = (itemId: String, name: String, stockId: String, qty: Int) => for { _ <- runIO( query[Item].insert(lift(Item(itemId, name))) ) _ <- runIO( query[Stock].insert(lift(Stock(stockId, itemId, qty))) ) } yield () // stock と item の select(stock と該当する item をタプル化) val selectStocks = quote { for { s <- query[Stock] i <- query[Item] if i.itemId == s.itemId } yield (i, s) } // 処理の合成 val proc = for { r1 <- IO.fromTry(createTables) _ <- insertItemAndStock("item1", "A1", "stock1", 5) _ <- insertItemAndStock("item2", "B2", "stock2", 3) r2 <- runIO(selectStocks) } yield (r1, r2) // 結果 println( performIO(proc) ) // トランザクションを適用する場合は以下のようにする //println( performIO(proc.transactional) ) }
実行結果
> cd sample2 > gradle run ・・・ [main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Starting... [main] INFO com.zaxxer.hikari.HikariDataSource - HikariPool-1 - Start completed. ((false,false),List((Item(item1,A1),Stock(stock1,item1,5)), (Item(item2,B2),Stock(stock2,item2,3)))) ・・・