CNN でランドマーク検出

前回の「CNNで輪郭の検出」 で試した手法を工夫し、ランドマーク(特徴点)検出へ適用してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190217/

輪郭の検出では画像をピクセル単位で二値分類(輪郭以外 = 0, 輪郭 = 1)しましたが、今回はこれを多クラス分類(ランドマーク以外 = 0, ランドマーク1 = 1, ランドマーク2 = 2, ・・・)へ変更します。

ちなみに、Deeplearning でランドマーク検出を行うような場合、ランドマークの座標を直接予測するような手法が考えられますが、今回試してみた限りでは納得のいく結果(座標の精度や汎用性など)を出せなくて、代わりに思いついたのが今回の手法となっています。

はじめに

データセット

今回は、DeepFashion: In-shop Clothes Retrieval のランドマーク用データセットから以下の条件を満たすものに限定して使います。

  • clothes_type の値が 1 (upper-body clothes)
  • variation_type の値が 1 (normal pose)
  • landmark_visibility_1 ~ 6 の値が 0(visible)

ランドマークには 6種類 (landmark_location_x_1 ~ 6、landmark_location_y_1 ~ 6) の座標を使います。

教師データ

入力データには画像を使うため、データ形状は (<バッチサイズ>, 256, 256, 3) ※ となります。

 ※ (<バッチサイズ>, <高さ>, <幅>, <チャンネル数>)

ラベルデータは landmark_location 1 ~ 6 の値を元に動的に生成します。

ピクセル単位でランドマーク以外(= 0)とランドマーク 1 ~ 6 の多クラス分類を行うため、データ形状は (<バッチサイズ>, 256, 256, 7) とします。

ランドマーク毎に 1ピクセルだけランドマークへ分類しても上手く学習できないので ※、一定の大きさ(範囲)をランドマークへ分類する必要があります。

 ※ 全てをランドマーク以外(= 0)とするようになってしまう

そこで、ランドマーク周辺の一定範囲をランドマークへ分類するとともに、以下の図(中心がランドマーク)のようにランドマークから離れると確率値が下がるように工夫します。

f:id:fits:20190217023108p:plain

学習

学習処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、list_landmarks_inshop.txt ファイルを読み込んで必要なデータを抜き出します。

今回は学習時間の短縮のため、先頭から 100件だけを使用しています。

データ読み込みとフィルタリング
import pandas as pd

df = pd.read_table('list_landmarks_inshop.txt', sep = '\s+', skiprows = 1)

s = 100

dfa = df[(df['clothes_type'] == 1) & (df['variation_type'] == 1) &
         (df['landmark_visibility_1'] == 0) & (df['landmark_visibility_2'] == 0) & 
         (df['landmark_visibility_3'] == 0) & (df['landmark_visibility_4'] == 0) &
         (df['landmark_visibility_5'] == 0) & (df['landmark_visibility_6'] == 0)][:s]

次に、入力データとして使う画像を読み込みます。

入力データ(画像)読み込み
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

imgs = np.array([ img_to_array(load_img(f)) for f in dfa['image_name']])

入力データの形状は以下のようになります。

imgs.shape
(100, 256, 256, 3)

(2) ラベルデータの生成

先述したように landmark_location の値から得られたランドマーク座標の周辺に確率値を設定していきます。

ここでは、確率の構成内容や確率値の設定対象とする周辺座標の取得処理を引数で指定できるようにしてみました。

また、他のランドマークの範囲と重なった場合、今回は単純に上書き(後勝ち)するようにしましたが、確率値の大きい方を選択するか確率値を分配するようにした方が望ましいと思われます。

ラベルデータ作成処理
cols = [f'landmark_location_{t}_{i + 1}' for i in range(6) for t in ['x', 'y'] ]
labels_t = dfa[cols].values.astype(int)

def gen_labels(prob, around_func):
    res = np.zeros(imgs.shape[:-1] + (int(len(cols) / 2) + 1,))
    res[:, :, :, 0] = 1.0
    
    for i in range(len(res)):
        r = res[i]
        
        # ランドマーク毎の設定
        for j in range(0, len(labels_t[i]), 2):
            # ランドマークの座標
            x = labels_t[i, j]
            y = labels_t[i, j + 1]
            
            # ランドマークの分類(1 ~ 6)
            c = int(j / 2) + 1
            
            for k in range(len(prob)):
                p = prob[k]
                
                # (相対的な)周辺座標の取得
                for a in around_func(k):
                    ax = x + a[0]
                    ay = y + a[1]
                    
                    if ax >= 0 and ax < imgs.shape[2] and ay >= 0 and ay < imgs.shape[1]:
                        # 他のランドマークと範囲が重なった場合への対応(設定値のクリア)
                        r[ay, ax, :] = 0.0
                        
                        # ランドマーク c へ該当する確率
                        r[ay, ax, c] = p
                        # ランドマーク以外へ該当する確率
                        r[ay, ax, 0] = 1.0 - p

    return res

今回は以下のような内容でラベルデータを作りました。

ラベルデータ作成
def around_square(n):
    return [(x, y) for x in range(-n, n + 1) for y in range(-n, n + 1) if abs(x) == n or abs(y) == n]

labels = gen_labels([1.0, 1.0, 1.0, 0.8, 0.8, 0.7, 0.7, 0.6, 0.6, 0.5], around_square)

ラベルデータの形状は以下の通りです。

labels.shape
(100, 256, 256, 7)

ラベルデータの内容確認

ランドマーク周辺の値を見てみると以下のようになっており、問題無さそうです。

labels[0, 59, 105:126]
array([[1. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0.5, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0.5, 0. , 0. , 0. , 0. , 0. ],
       [1. , 0. , 0. , 0. , 0. , 0. , 0. ]])
labels[0, 50:71, 149]
array([[1. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0. , 0.5, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.5, 0. , 0.5, 0. , 0. , 0. , 0. ],
       [1. , 0. , 0. , 0. , 0. , 0. , 0. ]])

これだけだと分かり難いので、単純な可視化を行ってみます。(ランドマークの該当確率をピクセル毎に合計しているだけ)

ラベルデータの可視化処理
matplotlib inline

import matplotlib.pyplot as plt

def imshow_label(index):
    plt.imshow(labels[index, :, :, 1:].sum(axis = -1), cmap = 'gray')
imshow_label(0)

f:id:fits:20190217023149p:plain

imshow_label(1)

f:id:fits:20190217023204p:plain

特に問題は無さそうです。

(3) CNN モデル

前回 と同様に Encoder-Decoder の構成を採用し、Encoder・Decoder をそれぞれ 1段階深くしました。(4段階に縮小して拡大)

多クラス分類を行うために、出力層の活性化関数を softmax にして、損失関数を categorical_crossentropy としています。

モデル内容
from keras.models import Model
from keras.layers import Input, Dense, Dropout, UpSampling2D
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(256, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(32, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(32, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(16, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(16, 3, padding = 'same', activation = 'relu')(x)

x = Dropout(0.3)(x)

output = Dense(labels.shape[-1], activation = 'softmax')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['acc'])

model.summary()
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_7 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
batch_normalization_46 (Batc (None, 256, 256, 3)       12        
_________________________________________________________________
conv2d_56 (Conv2D)           (None, 256, 256, 16)      448       
_________________________________________________________________
conv2d_57 (Conv2D)           (None, 256, 256, 16)      2320      
_________________________________________________________________
max_pooling2d_20 (MaxPooling (None, 128, 128, 16)      0         
_________________________________________________________________
batch_normalization_47 (Batc (None, 128, 128, 16)      64        
_________________________________________________________________
dropout_46 (Dropout)         (None, 128, 128, 16)      0         
_________________________________________________________________
conv2d_58 (Conv2D)           (None, 128, 128, 32)      4640      
_________________________________________________________________
conv2d_59 (Conv2D)           (None, 128, 128, 32)      9248      
_________________________________________________________________
conv2d_60 (Conv2D)           (None, 128, 128, 32)      9248      
_________________________________________________________________
max_pooling2d_21 (MaxPooling (None, 64, 64, 32)        0         
_________________________________________________________________
batch_normalization_48 (Batc (None, 64, 64, 32)        128       
_________________________________________________________________
dropout_47 (Dropout)         (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_61 (Conv2D)           (None, 64, 64, 64)        18496     
_________________________________________________________________
conv2d_62 (Conv2D)           (None, 64, 64, 64)        36928     
_________________________________________________________________
conv2d_63 (Conv2D)           (None, 64, 64, 64)        36928     
_________________________________________________________________
max_pooling2d_22 (MaxPooling (None, 32, 32, 64)        0         
_________________________________________________________________
batch_normalization_49 (Batc (None, 32, 32, 64)        256       
_________________________________________________________________
dropout_48 (Dropout)         (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_64 (Conv2D)           (None, 32, 32, 128)       73856     
_________________________________________________________________
conv2d_65 (Conv2D)           (None, 32, 32, 128)       147584    
_________________________________________________________________
conv2d_66 (Conv2D)           (None, 32, 32, 128)       147584    
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 16, 16, 128)       0         
_________________________________________________________________
batch_normalization_50 (Batc (None, 16, 16, 128)       512       
_________________________________________________________________
dropout_49 (Dropout)         (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_67 (Conv2D)           (None, 16, 16, 256)       295168    
_________________________________________________________________
batch_normalization_51 (Batc (None, 16, 16, 256)       1024      
_________________________________________________________________
dropout_50 (Dropout)         (None, 16, 16, 256)       0         
_________________________________________________________________
up_sampling2d_20 (UpSampling (None, 32, 32, 256)       0         
_________________________________________________________________
conv2d_transpose_44 (Conv2DT (None, 32, 32, 128)       295040    
_________________________________________________________________
conv2d_transpose_45 (Conv2DT (None, 32, 32, 128)       147584    
_________________________________________________________________
conv2d_transpose_46 (Conv2DT (None, 32, 32, 128)       147584    
_________________________________________________________________
batch_normalization_52 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
dropout_51 (Dropout)         (None, 32, 32, 128)       0         
_________________________________________________________________
up_sampling2d_21 (UpSampling (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_transpose_47 (Conv2DT (None, 64, 64, 64)        73792     
_________________________________________________________________
conv2d_transpose_48 (Conv2DT (None, 64, 64, 64)        36928     
_________________________________________________________________
conv2d_transpose_49 (Conv2DT (None, 64, 64, 64)        36928     
_________________________________________________________________
batch_normalization_53 (Batc (None, 64, 64, 64)        256       
_________________________________________________________________
dropout_52 (Dropout)         (None, 64, 64, 64)        0         
_________________________________________________________________
up_sampling2d_22 (UpSampling (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_transpose_50 (Conv2DT (None, 128, 128, 32)      18464     
_________________________________________________________________
conv2d_transpose_51 (Conv2DT (None, 128, 128, 32)      9248      
_________________________________________________________________
batch_normalization_54 (Batc (None, 128, 128, 32)      128       
_________________________________________________________________
dropout_53 (Dropout)         (None, 128, 128, 32)      0         
_________________________________________________________________
up_sampling2d_23 (UpSampling (None, 256, 256, 32)      0         
_________________________________________________________________
conv2d_transpose_52 (Conv2DT (None, 256, 256, 16)      4624      
_________________________________________________________________
conv2d_transpose_53 (Conv2DT (None, 256, 256, 16)      2320      
_________________________________________________________________
dropout_54 (Dropout)         (None, 256, 256, 16)      0         
_________________________________________________________________
dense_13 (Dense)             (None, 256, 256, 7)       119       
=================================================================
Total params: 1,557,971
Trainable params: 1,556,525
Non-trainable params: 1,446
_________________________________________________________________

(4) 学習

教師データ 100件では少なすぎると思いますが、今回はその中の 80件のみ学習に使用して 20件を検証に使ってみます。(validation_split で指定)

ここで、ランドマークとそれ以外でデータ数に大きな偏りがあるため(ランドマーク以外が大多数)、そのままでは上手く学習できない恐れがあります。

以下では class_weight を使ってランドマーク分類の重みを大きくしています。

実行例(351 ~ 400 エポック)
# 分類毎の重みを定義(ランドマークは 256*256 に設定)
wg = np.ones(labels.shape[-1]) * (imgs.shape[1] * imgs.shape[2])
# ランドマーク以外(= 0)の重み設定
wg[0] = 1

hist = model.fit(imgs, labels, initial_epoch = 350, epochs = 400, batch_size = 10, class_weight = wg, validation_split = 0.2)
結果例
Train on 80 samples, validate on 20 samples
Epoch 351/400
80/80 [===・・・ - loss: 0.0261 - acc: 0.9924 - val_loss: 0.1644 - val_acc: 0.9782
Epoch 352/400
80/80 [===・・・ - loss: 0.0263 - acc: 0.9924 - val_loss: 0.1638 - val_acc: 0.9784
・・・
Epoch 399/400
80/80 [===・・・ - loss: 0.0255 - acc: 0.9930 - val_loss: 0.1719 - val_acc: 0.9775
Epoch 400/400
80/80 [===・・・ - loss: 0.0253 - acc: 0.9931 - val_loss: 0.1720 - val_acc: 0.9777

(5) 確認

fit の戻り値から学習・検証の lossacc の値をそれぞれグラフ化してみます。

fit 結果表示
%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (16, 4)

plt.subplot(1, 4, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 4, 2)
plt.plot(hist.history['acc'])

plt.subplot(1, 4, 3)
plt.plot(hist.history['val_loss'])

plt.subplot(1, 4, 4)
plt.plot(hist.history['val_acc'])
結果例(351 ~ 400 エポック)

f:id:fits:20190217023309p:plain

val_lossval_acc の値が良くないのは、データ量が少なすぎる点にあると考えています。

(6) ランドマーク検出

下記 4種類の画像を出力して、ランドマーク検出(predict)結果とラベルデータ(正解)を比較してみます。

  • (a) ラベルデータの分類(ピクセル毎に確率値が最大の分類で色分け)
  • (b) 予測結果(predict)の分類(ピクセル毎に確率値が最大の分類で色分け)
  • (c) 元画像と (b) の重ね合わせ
  • (d) ランドマークの描画(各ランドマークの確率値が最大の座標へ円を描画)

今回はランドマークは分類毎に 1点のみなので、確率が最大値の座標がランドマークと判断できます。

ランドマーク検出と結果出力
import cv2

colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255), (255, 165, 0), (210, 180, 140)]

def predict(index, n = 0, c_size = 5, s = 5.0):
    plt.rcParams['figure.figsize'] = (s * 4, s)
    
    img = imgs[index]

    # 予測結果(ランドマーク分類結果)
    p = model.predict(np.array([img]))[0]

    # (a) ラベルデータの分類(ピクセル毎に確率値が最大の分類で色分け)
    img1 = np.apply_along_axis(lambda x: colors[x.argmax()], -1, labels[index])
    # (b) 予測結果の分類(ピクセル毎に確率値が最大の分類で色分け)
    img2 = np.apply_along_axis(lambda x: colors[x.argmax()], -1, p)
    # (c) 元画像への重ね合わせ
    img3 = cv2.addWeighted(img.astype(int), 0.4, img2, 0.6, 0)
    
    plt.subplot(1, 4, 1)
    plt.imshow(img1)
    
    plt.subplot(1, 4, 2)
    plt.imshow(img2)
    
    plt.subplot(1, 4, 3)
    plt.imshow(img3)

    img4 = img.astype(int)

    pdf = pd.DataFrame(
        [[np.argmax(vx), x, y, np.max(vx)] for y, vy in enumerate(p) for x, vx in enumerate(vy)], 
        columns = ['landmark', 'x', 'y', 'prob']
    )
    
    for c, v in pdf[pdf['landmark'] > 0].sort_values('prob', ascending = False).groupby('landmark'):
        # (d) ランドマークを描画(確率値が最大の座標へ円を描画)
        img4 = cv2.circle(img4, tuple(v[['x', 'y']].values[0]), c_size, colors[c], -1)
        
        if n > 0:
            print(f"landmark {c} : x = {labels_t[index, (c - 1) * 2]}, {labels_t[index, (c - 1) * 2 + 1]}")
            print(v[:n])

    plt.subplot(1, 4, 4)
    plt.imshow(img4)

学習データの結果例

左から (a) ラベルデータの分類、(b) 予測結果の分類、(c) 元画像との重ね合わせ、(d) ランドマーク検出結果となっています。

f:id:fits:20190217035001p:plain f:id:fits:20190217035116p:plain

ラベルデータにかなり近い結果が出ているように見えます。

下記のように、ランドマーク毎の確率値 TOP 3 とラベルデータを数値で比較してみると、かなり近い値になっている事を確認できました。

predict(0, n = 3)
landmark 1 : x = 115, 59
       landmark    x   y      prob
15475         1  115  60  0.893763
15476         1  116  60  0.893605
15220         1  116  59  0.893044

landmark 2 : x = 149, 60
       landmark    x   y      prob
15510         2  150  60  0.878173
15766         2  150  61  0.872413
15509         2  149  60  0.872222

landmark 3 : x = 82, 153
       landmark   x    y      prob
39250         3  82  153  0.882741
39249         3  81  153  0.881362
39248         3  80  153  0.879979

landmark 4 : x = 185, 150
       landmark    x    y      prob
38841         4  185  151  0.836826
38585         4  185  150  0.836212
38840         4  184  151  0.836164

landmark 5 : x = 93, 198
       landmark   x    y      prob
50782         5  94  198  0.829380
50526         5  94  197  0.825815
51038         5  94  199  0.825342

landmark 6 : x = 171, 197
       landmark    x    y      prob
50602         6  170  197  0.881702
50603         6  171  197  0.880731
50858         6  170  198  0.877772
predict(40, n = 3)
landmark 1 : x = 120, 42
      landmark    x   y      prob
8820         1  116  34  0.568582
9075         1  115  35  0.566257
9074         1  114  35  0.561259

landmark 2 : x = 134, 40
       landmark    x   y      prob
10372         2  132  40  0.812515
10371         2  131  40  0.807980
10628         2  132  41  0.807899

landmark 3 : x = 109, 48
       landmark    x   y      prob
12652         3  108  49  0.839624
12653         3  109  49  0.838190
12396         3  108  48  0.837235

landmark 4 : x = 148, 43
       landmark    x   y      prob
11156         4  148  43  0.837879
10900         4  148  42  0.837810
11157         4  149  43  0.836910

landmark 5 : x = 107, 176
       landmark    x    y      prob
45164         5  108  176  0.845494
45420         5  108  177  0.841054
45163         5  107  176  0.839846

landmark 6 : x = 154, 182
       landmark    x    y      prob
46746         6  154  182  0.865920
46747         6  155  182  0.863970
46490         6  154  181  0.862724

なお、predict(40) におけるランドマーク 1(赤色)の結果が振るわないのは、ラベルデータの作り方の問題だと考えられます。(上書きでは無く確率値が大きい方を採用する等で改善するはず)

検証データの結果例

f:id:fits:20190217035136p:plain f:id:fits:20190217035148p:plain

当然ながら、学習に使っていないこちらのデータでは結果が悪化していますが、それなりに正しそうな位置を部分的に検出しているように見えます。

学習に使ったデータ量の少なさを考えると、かなり良好な結果が出ているようにも思います。

そもそも、predict(-3) のようなランドマークの左右が反転している背面からの画像なんてのは無理があるように思いますし、predict(-8) のランドマーク 5(水色)はラベルデータの方が間違っている(検出結果の方が正しい)ような気もします。

predict(-1, n = 3)
landmark 1 : x = 96, 60
       landmark   x   y      prob
15969         1  97  62  0.872259
16225         1  97  63  0.869837
15970         1  98  62  0.869681

landmark 2 : x = 126, 59
       landmark    x   y      prob
16254         2  126  63  0.866628
16255         2  127  63  0.865502
15998         2  126  62  0.864939

landmark 3 : x = 66, 125
       landmark   x    y      prob
30521         3  57  119  0.832024
30520         3  56  119  0.831721
30777         3  57  120  0.829537

landmark 4 : x = 157, 117
       landmark    x    y      prob
29099         4  171  113  0.814012
29098         4  170  113  0.813680
28843         4  171  112  0.812420
predict(-8, n = 3)
landmark 1 : x = 133, 40
       landmark    x   y      prob
10629         1  133  41  0.812287
10628         1  132  41  0.810564
10373         1  133  40  0.808298

landmark 2 : x = 157, 47
       landmark    x   y      prob
12704         2  160  49  0.767413
12448         2  160  48  0.764577
12703         2  159  49  0.762571

landmark 3 : x = 105, 77
       landmark    x   y     prob
19300         3  100  75  0.79014
19301         3  101  75  0.78945
19556         3  100  76  0.78496

landmark 4 : x = 181, 86
       landmark    x    y      prob
56242         4  178  219  0.768471
55986         4  178  218  0.768215
56243         4  179  219  0.766977

landmark 5 : x = 137, 211
       landmark    x    y      prob
54370         5   98  212  0.710897
54626         5   98  213  0.707652
54372         5  100  212  0.707127

CNN で輪郭の検出

画像内の物体の輪郭検出を CNN(畳み込みニューラルネット)で試してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190114/

概要

今回は、画像をピクセル単位で輪郭か否かに分類する事(輪郭 = 1, 輪郭以外 = 0)で輪郭を検出できないか試しました。

そこで、教師データとして以下のような衣服単体の画像(jpg)と衣服の輪郭部分だけを白く塗りつぶした画像(png)を用意しました。

f:id:fits:20190114225837j:plain

教師データを大量に用意するのは困難だったため、240x288 の画像 160 ファイルで学習を行っています。

学習

学習の処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、入力画像(jpg)を読み込みます。(教師データの画像は img ディレクトリへ配置しています)

import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

files = glob.glob('img/*.jpg')

imgs = np.array([img_to_array(load_img(f)) for f in files])

imgs.shape

入力データの形状は以下の通りです。

imgs.shape 結果
(160, 288, 240, 3)

(2) ラベルデータの準備

輪郭画像(png)を読み込み、128 を境にして二値化(輪郭 = 1、輪郭以外 = 0)します。

import os

th = 128

labels = np.array([img_to_array(load_img(f"{os.path.splitext(f)[0]}.png", color_mode = 'grayscale')) for f in files])

labels[labels < th] = 0
labels[labels >= th] = 1

labels.shape

ラベルデータの形状は以下の通りです。

labels.shape 結果
(160, 288, 240, 1)

(3) CNN モデル

どのようなネットワーク構成が適しているのか分からなかったので、セマンティックセグメンテーション等で用いられている Encoder-Decoder の構成を参考にしてみました。

30x36 まで段階的に縮小して(Encoder)、元の大きさ 240x288 まで段階的に拡大する(Decoder)ようにしています。

最終層の活性化関数に sigmoid を使って 0 ~ 1 の値となるようにしています。

損失関数は binary_crossentropy を使うと進捗が遅そうに見えたので ※、代わりに mean_squared_error を使っています。

 ※ 今回の場合、輪郭(= 1)に該当するピクセルの方が少なくなるため
    binary_crossentropy を使用する場合は fit の class_weight 引数で
    調整する必要があったと思われます

また、参考のため mean_absolute_error(mae) の値も出力するように metrics で指定しています。

from keras.models import Model
from keras.layers import Input, Dropout
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

# Encoder

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

# Decoder

x = Conv2DTranspose(64, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(32, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(16, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

output = Conv2D(1, 1, activation = 'sigmoid')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])

model.summary()
model.summary() 結果
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 288, 240, 3)       0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 288, 240, 3)       12        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 288, 240, 16)      448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 144, 120, 16)      0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 144, 120, 16)      64        
_________________________________________________________________
dropout_1 (Dropout)          (None, 144, 120, 16)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 144, 120, 32)      4640      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 72, 60, 32)        0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 72, 60, 32)        128       
_________________________________________________________________
dropout_2 (Dropout)          (None, 72, 60, 32)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 72, 60, 64)        18496     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 36, 30, 64)        0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 36, 30, 64)        256       
_________________________________________________________________
dropout_3 (Dropout)          (None, 36, 30, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 36, 30, 128)       73856     
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 36, 30, 128)       147584    
_________________________________________________________________
batch_normalization_5 (Batch (None, 36, 30, 128)       512       
_________________________________________________________________
dropout_4 (Dropout)          (None, 36, 30, 128)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 72, 60, 64)        73792     
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
batch_normalization_6 (Batch (None, 72, 60, 64)        256       
_________________________________________________________________
dropout_5 (Dropout)          (None, 72, 60, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 144, 120, 32)      18464     
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
batch_normalization_7 (Batch (None, 144, 120, 32)      128       
_________________________________________________________________
dropout_6 (Dropout)          (None, 144, 120, 32)      0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 288, 240, 16)      4624      
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 288, 240, 16)      2320      
_________________________________________________________________
batch_normalization_8 (Batch (None, 288, 240, 16)      64        
_________________________________________________________________
dropout_7 (Dropout)          (None, 288, 240, 16)      0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 288, 240, 1)       17        
=================================================================
Total params: 576,541
Trainable params: 575,831
Non-trainable params: 710
_________________________________________________________________

(4) 学習

教師データが少ないため、全て学習で使う事にします。

実行例(441 ~ 480 エポック)
hist = model.fit(imgs, labels, initial_epoch = 440, epochs = 480, batch_size = 10)

Keras では fit を繰り返し呼び出すと学習を(続きから)再開できるので、40 エポックを何回か繰り返しました。(バッチサイズは 20 で始めて途中で 10 へ変えたりしています)

その場合、正しいエポックを出力するには initial_epochepochs の値を調整する必要があります ※。

 ※ initial_epoch を指定しなくても
    fit を繰り返し実行するだけで学習は継続されますが、
    その場合は出力されるエポックの値がクリアされます(1 からのカウントとなる)
結果例
Epoch 441/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
Epoch 442/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0125
Epoch 443/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
・・・
Epoch 478/480
160/160 [=====・・・ - loss: 0.0045 - mean_absolute_error: 0.0116
Epoch 479/480
160/160 [=====・・・ - loss: 0.0046 - mean_absolute_error: 0.0117
Epoch 480/480
160/160 [=====・・・ - loss: 0.0044 - mean_absolute_error: 0.0115

(5) 確認

fit の戻り値から mean_squared_error(loss)と mean_absolute_error の値の遷移をグラフ化してみます。

%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (8, 4)

plt.subplot(1, 2, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 2, 2)
plt.plot(hist.history['mean_absolute_error'])
結果例(441 ~ 480 エポック)

f:id:fits:20190114231037p:plain

(6) 検証

(a) 教師データ

教師データの入力画像と輪郭画像(ラベルデータ)、model.predict の結果を並べて表示してみます。

def predict(index, s = 6.0):
    plt.rcParams['figure.figsize'] = (s, s)

    sh = imgs.shape[1:-1]

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([imgs[index]]))[0]
    pred *= 255

    plt.subplot(1, 3, 1)
    # 入力画像の表示
    plt.imshow(imgs[index].astype(int))

    plt.subplot(1, 3, 2)
    # 輪郭画像(ラベルデータ)の表示
    plt.imshow(labels[index].reshape(sh), cmap = 'gray')

    plt.subplot(1, 3, 3)
    # predict の結果表示
    plt.imshow(pred.reshape(sh).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, 輪郭画像, model.predict 結果

f:id:fits:20190114231107j:plain

概ね教師データに近い結果が出るようになっています。

(b) 教師データ以外

教師データとして使っていない画像に対して model.predict を実施し、輪郭の検出を行ってみます。

def predict_eval(file, s = 4.0):
    plt.rcParams['figure.figsize'] = (s, s)

    img = img_to_array(load_img(file))

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    plt.subplot(1, 2, 1)
    # 入力画像の表示
    plt.imshow(img.astype(int))

    plt.subplot(1, 2, 2)
    # predict の結果表示
    plt.imshow(pred.reshape(pred.shape[:-1]).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, model.predict 結果

f:id:fits:20190114231154j:plain

所々で途切れたりしていますが、ある程度の輪郭は検出できているように見えます。

(7) 保存

学習したモデルを保存します。

model.save('model/c1_480.h5')

輪郭検出

学習済みモデルを使って輪郭検出を行う処理をスクリプト化してみました。

predict_contours.py
import sys
import os
import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
from keras.models import load_model
import cv2

model_file = sys.argv[1]
img_files = sys.argv[2]
dest_dir = sys.argv[3]

model = load_model(model_file)

for f in glob.glob(img_files):
    img = img_to_array(load_img(f))

    # 輪郭の検出
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    file, ext = os.path.splitext(os.path.basename(f))

    # 画像の保存
    cv2.imwrite(f"{dest_dir}/{file}_predict.png", pred)

    print(f"done: {f}")
実行例
python predict_contours.py model/c1_480.h5 img_eval2/*.jpg result

480 エポックの学習モデルを使って、教師データに無いタイプの背景を使った画像(影の影響もある)に試してみました。

輪郭検出結果例
入力画像 処理結果
f:id:fits:20181229154654j:plain f:id:fits:20190114231530p:plain
f:id:fits:20181229155049j:plain f:id:fits:20190114231547p:plain

こちらは難しかったようです。

なお、2つ目の画像は 120 エポックの学習モデルの方が良好な結果(輪郭がより多く検出されていた)でした。

MongoDB で条件に合致する子要素を抽出

MongoDB で指定の条件に合致する子要素のみを抽出する方法を調査してみました。

  • MongoDB 4.0.4

はじめに、下記 3つのドキュメントが sample コレクションへ登録されているとします。

ドキュメント内容
{ "_id" : 1, "items" : [
    { "color" : "black", "size" : "S" }, 
    { "color" : "white", "size" : "S" }
] }

{ "_id" : 2, "items" : [
    { "color" : "red",   "size" : "L" }, 
    { "color" : "blue",  "size" : "S" }
] }

{ "_id" : 3, "items" : [
    { "color" : "white", "size" : "L" }, 
    { "color" : "red",   "size" : "L" }, 
    { "color" : "white", "size" : "S" }
] }

ここで、items.colorwhite のものだけを抽出し、以下の結果を得る事を目指してみます。(items の中身が white のものだけを含むようにする)

目標とする検索結果
{ "_id" : 1, "items" : [
    { "color" : "white", "size" : "S" }
] }

{ "_id" : 3, "items" : [
    { "color" : "white", "size" : "L" }, 
    { "color" : "white", "size" : "S" }
] }

(a) items.color で条件指定

まずは {"items.color": "white"} の条件で find した結果です。

white を持つドキュメントだけを抽出できましたが、ドキュメントの内容はそのままなので black 等の余計なものも含んでしまいます。

> db.sample.find({"items.color": "white"})

{ "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

(b) $elemMatch 使用

次に $elemMatch を使ってみます。

$elemMatch を find の query(第一引数)で使うか、projection(第二引数)で使うかで結果が変わります。

query で使う場合は先程の (a) と同じ結果になります。

query で使用
> db.sample.find({"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

query の条件は指定せずに projection で $elemMatch を使った場合の結果は以下です。

全ドキュメントを対象に items の中身がフィルタリングされていますが、条件に合致する全ての子要素が抽出されるわけでは無く、(条件に合致する)先頭の要素しか含まれていません。

projection で使用1
> db.sample.find({}, {"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 2 }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }

query 条件を指定する事で不要なドキュメントを除く事はできますが、条件に合致する全ての子要素が抽出されない事に変わりはありません。

projection で使用2
> db.sample.find({"items.color": "white"}, {"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }

このように、今回のようなドキュメントに対して $elemMatch を使うと、条件に合致する最初の子要素だけが抽出されるようです。(何らかの回避策があるのかもしれませんが)

(c) aggregate 使用

最後に aggregate を使ってみます。

$unwind を使うと配列の個々の要素を処理できるので、$match で white のみに限定した後、$group でグルーピングすれば良さそうです。

> db.sample.aggregate([
  {$unwind: "$items"}, 
  {$match: {"items.color": "white"}}, 
  {$group: {_id: "$_id", "items": {$push: "$items"}}}
])

{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }

これで目指した結果は一応得られました。

なお、対象を最初に絞り込むようにしてソートを付けると以下のようになります。

> db.sample.aggregate([
  {$match: {"items.color": "white"}},
  {$unwind: "$items"},
  {$match: {"items.color": "white"}},
  {$group: {_id: "$_id", "items": {$push: "$items"}}},
  {$sort:  {_id: 1}}
])

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

Scala のケースクラスに制約を持たせる

Scala のケースクラスで値に制約を持たせたい場合にどうするか。

例えば、以下のケースクラスで amount の値を 0 以上となるように制限し、0 未満ならインスタンス化を失敗させる事を考えてみます。

case class Quantity(amount: Int)

使用した環境は以下

ソースは http://github.com/fits/try_samples/tree/master/blog/20181009/

ケースクラスの値を制限

まず、最も単純なのは以下のような実装だと思います。

case class Quantity(amount: Int) {
  if (amount < 0)
    throw new IllegalArgumentException(s"amount($amount) < 0")
}

これだと、例外が throw されてしまい関数プログラミングで扱い難いので Try[Quantity]Option[Quantity] 等を返すようにしたいところです。

そこで、以下のようにケースクラスを abstract 化して、コンパニオンオブジェクトへ生成関数を定義する方法を使ってみました。

sample.scala
import scala.util.{Try, Success, Failure}

sealed abstract case class Quantity private (amount: Int)

object Quantity {
  def apply(amount: Int): Try[Quantity] =
    if (amount >= 0)
      Success(new Quantity(amount){})
    else
      Failure(new IllegalArgumentException(s"amount($amount) < 0"))
}

println(Quantity(1))
println(Quantity(0))
println(Quantity(-1))

// この方法では Quantity へ copy がデフォルト定義されないため
// copy は使えません(error: value copy is not a member of this.Quantity)
//
// println(Quantity(1).map(_.copy(-1)))

実行結果は以下の通りです。

実行結果
> scala sample.scala

Success(Quantity(1))
Success(Quantity(0))
Failure(java.lang.IllegalArgumentException: amount(-1) < 0)

上記 sample.scala では、以下を直接呼び出せないようにしてケースクラスの勝手なインスタンス化を防止しています。

  • (a) コンストラクタ(new)
  • (b) コンパニオンオブジェクトへデフォルト定義される apply
  • (c) ケースクラスへデフォルト定義される copy

そのために、下記 2点を実施しています。

  • (1) コンストラクタの private 化 : (a) の防止
  • (2) ケースクラスの abstract 化 : (b) (c) の防止

(1) コンストラクタの private 化

以下のように private を付ける事でコンストラクタを private 化できます。

コンストラクタの private 化
case class Quantity private (amount: Int)

これで (a) new Quantity(・・・) の実行を防止できますが、以下のように (b) の apply や (c) の copy を実行できてしまいます。

検証例
scala> case class Quantity private (amount: Int)
defined class Quantity

scala> new Quantity(1)
<console>:14: error: constructor Quantity in class Quantity cannot be accessed in object $iw
       new Quantity(1)

scala> Quantity(1)
res1: Quantity = Quantity(1)

scala> Quantity.apply(2)
res2: Quantity = Quantity(2)

scala> Quantity(3).copy(30)
res3: Quantity = Quantity(30)

(2) ケースクラスの abstract 化

ケースクラスを abstract 化すると、通常ならデフォルト定義されるコンパニオンオブジェクトの apply やケースクラスの copy を防止できるようです。

そのため、(1) と組み合わせることで (a) ~ (c) を防止できます。

ケースクラスの abstract 化とコンストラクタの private 化
sealed abstract case class Quantity private (amount: Int)

以下のように Quantity.apply は定義されなくなります。

検証例
scala> sealed abstract case class Quantity private (amount: Int)
defined class Quantity

scala> new Quantity(1){}
<console>:14: error: constructor Quantity in class Quantity cannot be accessed in <$anon: Quantity>
       new Quantity(1){}
           ^

scala> Quantity.apply(1)
<console>:14: error: value apply is not a member of object Quantity
       Quantity.apply(1)

このままだと何もできなくなるため、実際はコンパニオンオブジェクトへ生成用の関数が必要になります。

sealed abstract case class Quantity private (amount: Int)

object Quantity {
  def create(amount: Int): Quantity = new Quantity(amount){}
}

備考

今回の方法は、以下の書籍に記載されているような ADTs(algebraic data types)と Smart constructors をより安全に定義するために活用できると考えています。

Functional and Reactive Domain Modeling

Functional and Reactive Domain Modeling

Kotlin の関数型プログラミング用ライブラリ Λrrow を試してみる

Kotlin で ScalaScalazCats のような関数型プログラミング用のライブラリを探していたところ、以下を見つけたので試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180822/

はじめに

Λrrow は以下のような要素で構成されており、Haskell の型クラス・インスタンスのような仕組みを実現しているようです。

  • Datatypes (Option, Either, Try, Kleisli, StateT, WriterT 等)
  • Typeclasses (Functor, Applicative, Monad, Monoid, Eq, Show 等)
  • Instances (OptionMonadInstance 等)

そして、その実現にはアノテーションプロセッサが活用されているようです。(@higherkind@instance から補助的なコードを生成)

Option のソースコード

例えば、Optionソースコードは以下のようになっていますが、OptionOf の定義は github のソース上には見当たりません。

arrow/core/Option.kt
@higherkind
sealed class Option<out A> : OptionOf<A> {
    ・・・
}

OptionOf はアノテーションプロセッサ ※ で生成したコード(以下)で定義されています。(OptionOf<A>Kind<ForOption, A> の型エイリアス

higherkind.arrow.core.Option.kt
package arrow.core

class ForOption private constructor() { companion object }
typealias OptionOf<A> = arrow.Kind<ForOption, A>

@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <A> OptionOf<A>.fix(): Option<A> =
  this as Option<A>
 ※ @higherkind に対する処理は
    arrow.higherkinds.HigherKindsFileGenerator に実装されているようです

ここで、ForOption というクラスが定義されていますが、この For + 型名 はコンテナ型を表現するための型のようです。

ForOption がある事で Kind<Option<A>, A> ではなく Kind<ForOption, A> と定義できるようになっており、Functor<F>Monad<F> における F の具体型として使われています。(例. Monad<ForOption>

Either のソースコード

次に Eitherソースコードも見てみます。

Either のように型パラメータが複数の場合はアノテーションプロセッサで生成されるコードの内容に多少の違いがあるようです。

arrow/core/Either.kt
@higherkind
sealed class Either<out A, out B> : EitherOf<A, B> {
    ・・・
}

アノテーションプロセッサで生成されたコードは以下のように EitherOf の他に EitherPartialOf が型エイリアスとして定義されています。

higherkind.arrow.core.Either.kt
package arrow.core

class ForEither private constructor() { companion object }
typealias EitherOf<A, B> = arrow.Kind2<ForEither, A, B>
typealias EitherPartialOf<A> = arrow.Kind<ForEither, A>

@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <A, B> EitherOf<A, B>.fix(): Either<A, B> =
  this as Either<A, B>

また、上記で定義されている ForEither クラスとは別に arrow-instances-core モジュールには ForEither 関数 も用意されています。

arrow/instances/either.kt
・・・
class EitherContext<L> : EitherMonadErrorInstance<L>, EitherTraverseInstance<L>, EitherSemigroupKInstance<L> {
  override fun <A, B> Kind<EitherPartialOf<L>, A>.map(f: (A) -> B): Either<L, B> =
    fix().map(f)
}

class EitherContextPartiallyApplied<L> {
  infix fun <A> extensions(f: EitherContext<L>.() -> A): A =
    f(EitherContext())
}

fun <L> ForEither(): EitherContextPartiallyApplied<L> =
  EitherContextPartiallyApplied()

この ForEither 関数は extensions を呼び出す際に使用する事になります。

サンプルコード

Option と Either をそれぞれ使ったサンプルを作成してみます。

(a) Option の利用

Option の SomeNone の作成にはいくつかの方法が用意されています。

Monad の拡張関数として定義されている binding を使うと Haskell の do 記法や Scala の for 式のようにモナドを処理できるようです。(Kotlin の coroutines 機能で実現)

binding の呼び出し方はいくつか考えられますが、ForOption extensions { ・・・ } を使うのが基本のようです。

ForOption extensions へ渡す処理内では this が OptionContext となるため、OptionContext の処理を呼び出せるようになっています。

また、binding へ渡す処理内では MonadContinuation<ForOption, *> が this となります。

bind は MonadContinuation 内で suspend fun <B> Kind<F, B>.bind(): B と定義されており、Kind<F, B> から B の値を取り出す処理となっています。

src/main/kotlin/App.kt
import arrow.core.*
import arrow.instances.*
import arrow.typeclasses.binding

fun main(args: Array<String>) {
    // Some
    val d1: Option<Int> = Option.just(10)
    val d2: Option<Int> = 5.some()
    val d3: Option<Int> = Some(2)
    // None
    val d4: Option<Int> = Option.empty()
    val d5: Option<Int> = none()
    val d6: Option<Int> = None

    // Some(15)
    val r1 = d1.flatMap { a ->
        d2.map { b -> a + b }
    }

    println(r1)

    // Some(17)
    val r2 = Option.monad().binding { // this: MonadContinuation<ForOption, *>
        val a = d1.bind() // 10
        val b = d2.bind() //  5
        val c = d3.bind() //  2
        a + b + c
    }

    println(r2)

    // Some(17)
    val r3 = ForOption extensions { // this: OptionContext
        println(this) // arrow.instances.OptionContext@3ffc5af1

        binding { // this: MonadContinuation<ForOption, *>
            println(this) // arrow.typeclasses.MonadContinuation@26653222

            val a = d1.bind()
            val b = d2.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r3)

    // Some(17)
    val r4 = OptionContext.binding {
        val a = d1.bind()
        val b = d2.bind()
        val c = d3.bind()
        a + b + c
    }

    println(r4)

    // None
    val r5 = Option.monad().binding {
        val a = d1.bind()
        val b = d4.bind()
        a + b
    }

    println(r5)

    // None
    val r6 = ForOption extensions {
        binding {
            val a = d5.bind()
            val b = d2.bind()
            val c = d6.bind()
            a + b + c
        }
    }

    println(r6)
}

Option のような基本的な型を使うだけであれば、依存ライブラリとして arrow-instances-core を指定するだけで問題なさそうです。

そのため、Gradle ビルド定義ファイルは下記のような内容になります。

ここでは The feature "coroutines" is experimental (see: https://kotlinlang.org/docs/diagnostics/experimental-coroutines) 警告ログを出力しないように coroutines を有効化しています。

build.gradle (Gradle ビルド定義)
plugins {
    id 'org.jetbrains.kotlin.jvm' version '1.2.51'
    id 'application'
}

mainClassName = 'AppKt' // 実行するクラス名

repositories {
    jcenter()
}

dependencies {
    compile 'org.jetbrains.kotlin:kotlin-stdlib-jdk8'
    compile 'io.arrow-kt:arrow-instances-core:0.7.3'
}

// coroutines の有効化(ビルド時の警告を抑制)
kotlin {
    experimental {
        coroutines 'enable'
    }
}

実行結果は以下の通りです。

実行結果
> gradle run

・・・
Some(15)
Some(17)
arrow.instances.OptionContext@3ffc5af1
arrow.typeclasses.MonadContinuation@26653222
Some(17)
Some(17)
None
None

(b) Either の利用

Either の場合、extensions の呼び出し方が Option とは多少異なります。

Option の場合は ForOption の extensions を呼び出しましたが、Either の場合は ForEither に extensions は用意されておらず、代わりに EitherContextPartiallyApplied<L> クラス内に定義されています。

EitherContextPartiallyApplied オブジェクトは ForEither 関数 で取得できるので、これを使って ForEither<Left 側の型>() extensions { ・・・ } のようにします。

ForEither<String>() extensions へ渡す処理内の this は EitherContext<String> で、binding へ渡す処理内の this は MonadContinuation<EitherPartialOf<String>, *> となります。

src/main/kotlin/App.kt
import arrow.core.*
import arrow.instances.*
import arrow.typeclasses.binding

fun main(args: Array<String>) {
    // Right
    val d1: Either<String, Int> = Either.right(10)
    val d2: Either<String, Int> = 5.right()
    val d3: Either<String, Int> = Right(2)
    // Left
    val d4: Either<String, Int> = Either.left("error data")

    // Right(b=15)
    val r1 = d1.flatMap { a ->
        d2.map { b -> a + b }
    }

    println(r1)

    // Right(b=17)
    val r2 = Either.monad<String>().binding { // this: MonadContinuation<EitherPartialOf<String>, *>
        val a = d1.bind()
        val b = d2.bind()
        val c = d3.bind()
        a + b + c
    }

    println(r2)

    // Right(b=17)
    // ForEither<String>() 関数の呼び出し
    val r3 = ForEither<String>() extensions { // this: EitherContext<String>
        binding { // this: MonadContinuation<EitherPartialOf<String>, *>
            val a = d1.bind()
            val b = d2.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r3)

    // Left(a=error data)
    val r4 = ForEither<String>() extensions {
        binding {
            val a = d1.bind()
            val b = d4.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r4)
}

(a) と同じ内容の build.gradle を使って実行します。

実行結果
> gradle run

・・・
Right(b=15)
Right(b=17)
Right(b=17)
Left(a=error data)

TypeScript で funfix を使用 - tsc, FuseBox

funfixJavaScript, TypeScript, Flow の関数型プログラミング用ライブラリで、Fantasy LandStatic Land ※ に準拠し Scala の Option, Either, Try, Future 等と同等の型が用意されているようです。

 ※ JavaScript 用に Monoid や Monad 等の代数的構造に関する仕様を定義したもの

今回は Option を使った単純な処理を TypeScript で実装し Node.js で実行してみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180730/

はじめに

Option を使った下記サンプルをコンパイルして実行します。

サンプルソース
import { Option, Some } from 'funfix'

const f = (ma, mb) => ma.flatMap(a => mb.map(b => `${a} + ${b} = ${a + b}`))

const d1 = Some(10)
const d2 = Some(2)

console.log( d1 )

console.log('-----')

console.log( f(d1, d2) )
console.log( f(d1, Option.none()) )

console.log('-----')

console.log( f(d1, d2).getOrElse('none') )
console.log( f(d1, Option.none()).getOrElse('none') )

ビルドと実行

上記ソースファイルを以下の 2通りでビルドして実行してみます。

  • (a) tsc 利用
  • (b) FuseBox 利用

(a) tsc を利用する場合

tsc コマンドを使って TypeScript のソースをコンパイルします。

まずは typescript と funfix モジュールをそれぞれインストールします。

typescript インストール
> npm install --save-dev typescript
funfix インストール
> npm install --save funfix

この状態で sample.ts ファイルをコンパイルしてみると、型関係のエラーが出るものの sample.js は正常に作られました。

コンパイル1
> tsc sample.ts

node_modules/funfix-core/dist/disjunctions.d.ts:775:14 - error TS2416: Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'.
  Type 'undefined' is not assignable to type 'never'.

775     readonly value: undefined;
                 ~~~~~


node_modules/funfix-effect/dist/eval.d.ts:256:42 - error TS2304: Cannot find name 'Iterable'.

256     static sequence<A>(list: Eval<A>[] | Iterable<Eval<A>>): Eval<A[]>;
                                             ~~~~~~~~
・・・

sample.js を実行してみると特に問題無く動作します。

実行1
> node sample.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

これで一応は動いたわけですが、コンパイル時にエラーが出るというのも望ましい状態ではないので、エラー部分を解決してみます。

他にも方法があるかもしれませんが、ここでは以下のように対応します。

  • (1) Property 'value' in type 'TNone' ・・・ 'Option<never>' のエラーに対して tsc 実行時に --strictNullChecks オプションを指定して対応
  • (2) Cannot find name 'Iterable' 等のエラーに対して @types/node をインストールして対応

strictNullChecks は tsc の実行時オプションで指定する以外にも設定ファイル tsconfig.json で設定する事もできるので、ここでは tsconfig.json ファイルを使います。

(1) tsconfig.json
{
  "compilerOptions": {
    "strictNullChecks": true
  }
}

次に @types/node をインストールします。

@types/node には Node.js で実行するための型定義(Node.js 依存の API 等)が TypeScript 用に定義されています。

(2) @types/node インストール
> npm install --save-dev @types/node

この状態で tsc を実行すると先程のようなエラーは出なくなりました。(tsconfig.json を適用するため tsc コマンドへ引数を指定していない点に注意)

コンパイル2
> tsc

実行結果にも差異はありません。

実行2
> node sample.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

最終的な package.json の内容は以下の通りです。

package.json
{
  "name": "sample",
  "version": "1.0.0",
  "devDependencies": {
    "@types/node": "^10.5.4",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

(b) FuseBox を利用する場合

次に、モジュールバンドラーの FuseBox を使用してみます。(以降は (a) とは異なるディレクトリで実施)

なお、ここでは npm の代わりに yarn を使っていますが、npm でも特に問題はありません。

yarn のインストール例(npm 使用)
> npm install -g yarn

(b-1) 型チェック無し

まずは typescript, fuse-box, funfix をそれぞれインストールしておきます。

typescript と fuse-box インストール
> yarn add --dev typescript fuse-box
funfix インストール
> yarn add funfix

FuseBox ではビルド定義を JavaScript のコードで記載します。 とりあえずは必要最小限の設定を行いました。

bundle で指定した名称が init の $name に適用されるため、*.tsコンパイル結果と依存モジュールの内容をバンドルして bundle.js へ出力する事になります。

なお、> でロード時に実行する(コードを記載した)ファイルを指定します。

fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box')

const fuse = FuseBox.init({
    output: '$name.js'
})

fuse.bundle('bundle').instructions('> *.ts')

fuse.run()

上記スクリプトを実行してビルド(TypeScript のコンパイルとバンドル)を行います。

ビルド
> node fuse.js

--- FuseBox 3.4.0 ---
  → Generating recommended tsconfig.json:  ・・・\sample_fusebox1\tsconfig.json
  → Typescript script target: ES7

--------------------------
Bundle "bundle"

    sample.js
└──  (1 files,  700 Bytes) default
└── funfix-core 34.4 kB (1 files)
└── funfix-effect 43.1 kB (1 files)
└── funfix-exec 79.5 kB (1 files)
└── funfix 1 kB (1 files)
size: 158.7 kB in 765ms

初回実行時にデフォルト設定の tsconfig.json が作られました。(tsconfig.json が存在しない場合)

tsc の時のような型関係のエラーは出ていませんが、これは FuseBox がデフォルトで TypeScript の型チェックをしていない事が原因のようです。

型チェックを実施するには fuse-box-typechecker プラグインを使う必要がありそうです。

実行
> node bundle.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

package.json の内容は以下の通りです。

package.json
{
  "name": "sample_fusebox1",
  "version": "1.0.0",
  "main": "bundle.js",
  "license": "MIT",
  "devDependencies": {
    "fuse-box": "^3.4.0",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

(b-2) 型チェック有り

TypeScript の型チェックを行うようにしてみます。

まずは、(b-1) と同じ構成に fuse-box-typechecker プラグインを加えます。

fuse-box-typechecker を追加インストール
> yarn add --dev fuse-box-typechecker

次に、fuse.js へ fuse-box-typechecker プラグインの設定を追加します。

TypeChecker で型チェックにエラーがあった場合、例外が throw されるようにはなっていないため、ここではエラーがあった場合に Error を throw して fuse.run() を実行しないようにしてみました。

ただし、こうすると tsconfig.json を予め用意しておく必要があります。(TypeChecker に tsconfig.json が必要)

fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box')
const {TypeChecker} = require('fuse-box-typechecker')

// fuse-box-typechecker の設定
const testSync = TypeChecker({
    tsConfig: './tsconfig.json'
})

const fuse = FuseBox.init({
    output: '$name.js'
})

fuse.bundle('bundle').instructions('> *.ts')

testSync.runPromise()
    .then(n => {
        if (n != 0) {
            // 型チェックでエラーがあった場合
            throw new Error(n)
        }
        // 型チェックでエラーがなかった場合
        return fuse.run()
    })
    .catch(console.error)

これで、ビルド時に (a) と同様の型エラーが出るようになりました。

ビルド1
> node fuse.js

・・・
--- FuseBox 3.4.0 ---

Typechecker plugin(promisesync) .
Time:Sun Jul 29 2018 12:40:47 GMT+0900 (GMT+09:00)

File errors:
└── .\node_modules\funfix-core\dist\disjunctions.d.ts
   | ・・・\sample_fusebox2\node_modules\funfix-core\dist\disjunctions.d.ts (775,14) (Error:TS2416) Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'.
  Type 'undefined' is not assignable to type 'never'.

Errors:1
└── Options: 0
└── Global: 0
└── Syntactic: 0
└── Semantic: 1
└── TsLint: 0

Typechecking time: 4116ms
Quitting typechecker

・・・

ここで、Iterable の型エラーが出ていないのは fuse-box-typechecker のインストール時に @types/node もインストールされているためです。

(a) と同様に strictNullChecks の設定を tsconfig.json追記して、このエラーを解決します。

tsconfig.json へ strictNullChecks の設定を追加
{
  "compilerOptions": {
    "module": "commonjs",
    "target": "ES7",
    ・・・
    "strictNullChecks": true
  }
}

これでビルドが成功するようになりました。

ビルド2
> node fuse.js

・・・
Typechecker name: undefined
Typechecker basepath: ・・・\sample_fusebox2
Typechecker tsconfig: ・・・\sample_fusebox2\tsconfig.json
--- FuseBox 3.4.0 ---

Typechecker plugin(promisesync) .
Time:Sun Jul 29 2018 12:44:57 GMT+0900 (GMT+09:00)
All good, no errors :-)
Typechecking time: 4103ms
Quitting typechecker

killing worker  → Typescript config file:  \tsconfig.json
  → Typescript script target: ES7

--------------------------
Bundle "bundle"

    sample.js
└──  (1 files,  700 Bytes) default
└── funfix-core 34.4 kB (1 files)
└── funfix-effect 43.1 kB (1 files)
└── funfix-exec 79.5 kB (1 files)
└── funfix 1 kB (1 files)
size: 158.7 kB in 664ms
実行結果
> node bundle.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

package.json の内容は以下の通りです。

package.json
{
  "name": "sample_fusebox2",
  "version": "1.0.0",
  "main": "index.js",
  "license": "MIT",
  "devDependencies": {
    "fuse-box": "^3.4.0",
    "fuse-box-typechecker": "^2.10.0",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

Word2Vec を用いた併売の分析 - gensim

トピックモデルを用いた併売の分析」ではトピックモデルによる併売の分析を試しましたが、今回は gensim の Word2Vec で試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180617/

はじめに

データセット

これまで は適当に作ったデータセットを使っていましたが、今回は R の Groceries データセット ※ をスペース区切りのテキストファイル(groceries.txt)にして使います。(商品名にスペースを含む場合は代わりに _ を使っています)

 ※ ある食料雑貨店における 30日間の POS データ
groceries.txt
citrus_fruit semi-finished_bread margarine ready_soups
tropical_fruit yogurt coffee
whole_milk
pip_fruit yogurt cream_cheese_ meat_spreads
other_vegetables whole_milk condensed_milk long_life_bakery_product
whole_milk butter yogurt rice abrasive_cleaner
rolls/buns
other_vegetables UHT-milk rolls/buns bottled_beer liquor_(appetizer)
pot_plants
whole_milk cereals
・・・
cooking_chocolate
chicken citrus_fruit other_vegetables butter yogurt frozen_dessert domestic_eggs rolls/buns rum cling_film/bags
semi-finished_bread bottled_water soda bottled_beer
chicken tropical_fruit other_vegetables vinegar shopping_bags

R によるアソシエーションルールの抽出結果

参考のため、まずは R を使って Groceries データセットapriori で処理しました。

リフト値を優先するため、支持度 (supp) と確信度 (conf) を低めの値にしています。

groceries_apriori.R
library(arules)
data(Groceries)

params <- list(supp = 0.001, conf = 0.1)

rules <- apriori(Groceries, parameter = params)

inspect(head(sort(rules, by = "lift"), 10))
実行結果
> Rscript groceries_apriori.R

・・・
     lhs                        rhs                         support confidence     lift count
[1]  {bottled beer,                                                                          
      red/blush wine}        => {liquor}                0.001931876  0.3958333 35.71579    19
[2]  {hamburger meat,                                                                        
      soda}                  => {Instant food products} 0.001220132  0.2105263 26.20919    12
[3]  {ham,                                                                                   
      white bread}           => {processed cheese}      0.001931876  0.3800000 22.92822    19
[4]  {root vegetables,                                                                       
      other vegetables,                                                                      
      whole milk,                                                                            
      yogurt}                => {rice}                  0.001321810  0.1688312 22.13939    13
[5]  {bottled beer,                                                                          
      liquor}                => {red/blush wine}        0.001931876  0.4130435 21.49356    19
[6]  {Instant food products,                                                                 
      soda}                  => {hamburger meat}        0.001220132  0.6315789 18.99565    12
[7]  {curd,                                                                                  
      sugar}                 => {flour}                 0.001118454  0.3235294 18.60767    11
[8]  {soda,                                                                                  
      salty snack}           => {popcorn}               0.001220132  0.1304348 18.06797    12
[9]  {sugar,                                                                                 
      baking powder}         => {flour}                 0.001016777  0.3125000 17.97332    10
[10] {processed cheese,                                                                      
      white bread}           => {ham}                   0.001931876  0.4634146 17.80345    19

bottled beer と red/blush wine で liquor が同時に買われやすい、hamburger meat と soda で Instant food products が同時に買われやすいという結果(アソシエーションルール)が出ています。

Word2Vec の適用

それでは groceries.txt を gensim の Word2Vec で処理してみます。 とりあえず iter を 500 に min_count を 1 にしてみました。

なお、購入品目の多い POS データを処理する場合は window パラメータを大きめにすべきかもしれません。(今回はデフォルト値の 5)

今回は Jupyter Notebook で実行しています。

Word2Vec モデルの構築
from gensim.models import word2vec

sentences = word2vec.LineSentence('groceries.txt')

model = word2vec.Word2Vec(sentences, iter = 500, min_count = 1)

類似品の算出

まずは、wv.most_similar で類似単語(商品)を抽出してみます。

pork の類似単語
model.wv.most_similar('pork')
[('turkey', 0.5547687411308289),
 ('ham', 0.49448296427726746),
 ('pip_fruit', 0.46879759430885315),
 ('tropical_fruit', 0.4383287727832794),
 ('butter', 0.43373265862464905),
 ('frankfurter', 0.4334157109260559),
 ('root_vegetables', 0.4249211549758911),
 ('citrus_fruit', 0.4246293306350708),
 ('chicken', 0.42378148436546326),
 ('sausage', 0.41153857111930847)]

微妙なものも含んでいますが、それなりの結果になっているような気もします。

most_similar はベクトル的に類似している単語を抽出するため、POS データを処理する場合は競合や代用品の抽出に使えるのではないかと思います。

併売の分析

併売の商品はお互いに類似していないと思うので most_similar は役立ちそうにありませんが、それでも何らかの関係性はありそうな気がします。

そこで、指定した単語群の中心となる単語を抽出する predict_output_word を使えないかと思い、R で抽出したアソシエーションルールの組み合わせで試してみました。

predict_output_word の検証

bottled_beer と red/blush_wine
model.predict_output_word(['bottled_beer', 'red/blush_wine'])
[('liquor', 0.22384332),
 ('prosecco', 0.04933687),
 ('sparkling_wine', 0.0345262),
 ・・・]

R の結果に出てた liquor が先頭(確率が最大)に来ています。

bottled_beer と red/blush_wine
model.predict_output_word(['hamburger_meat', 'soda'])
[('Instant_food_products', 0.054281656),
 ('canned_vegetables', 0.029985178),
 ('pasta', 0.025487985),
 ・・・]

ここでも R の結果に出てた Instant_food_products が先頭に来ています。

ham と white_bread
model.predict_output_word(['ham', 'white_bread'])
[('processed_cheese', 0.20990367),
 ('sweet_spreads', 0.024131883),
 ('spread_cheese', 0.023222428),
 ・・・]

こちらも同様です。

root_vegetables と other_vegetables と whole_milk と yogurt
model.predict_output_word(['root_vegetables', 'other_vegetables', 'whole_milk', 'yogurt'])
[('herbs', 0.024541182),
 ('liver_loaf', 0.019327056),
 ('turkey', 0.01775743),
 ('onions', 0.01760579),
 ('specialty_cheese', 0.014991459),
 ('packaged_fruit/vegetables', 0.014529809),
 ('spread_cheese', 0.012931713),
 ('meat', 0.012434797),
 ('beef', 0.011924307),
 ('butter_milk', 0.011828974)]

R の結果にあった rice はこの中には含まれていません。

curd と sugar
model.predict_output_word(['curd', 'sugar'])
[('flour', 0.076272935),
 ('pudding_powder', 0.055790607),
 ('baking_powder', 0.026003197),
 ・・・]

R の結果に出てた flour (小麦粉) が先頭に来ています。

soda と salty_snack
model.predict_output_word(['soda', 'salty_snack'])
[('popcorn', 0.05830234),
 ('nut_snack', 0.046429735),
 ('chewing_gum', 0.0213278),
 ・・・]

こちらも同様です。

sugar と baking_powder
model.predict_output_word(['sugar', 'baking_powder'])
[('flour', 0.11954326),
 ('cooking_chocolate', 0.046284538),
 ('pudding_powder', 0.03714784),
 ・・・]

こちらも同様です。

以上のように、少なくとも 2品を指定した場合の predict_output_word の結果は R で抽出したアソシエーションルールに合致しているようです。

Word2Vec のパラメータに左右されるのかもしれませんが、この結果を見る限りでは predict_output_word を 3品の併売の組み合わせ抽出に使えるかもしれない事が分かりました。

3品の併売

次に predict_output_word で 2品に対する 1品を確率の高い順に抽出してみました。

なお、ここでは 3品の組み合わせの購入数が 10 未満のものは除外するようにしています。

from collections import Counter
import itertools

# 3品の組み合わせのカウント
tri_counter = Counter([c for ws in sentences for c in itertools.combinations(sorted(ws), 3)])

# 2品の組み合わせを作成
pairs = itertools.combinations(model.wv.vocab.keys(), 2)

sorted([
    (p, item, prob) for p in pairs for item, prob in model.predict_output_word(p)
    if prob >= 0.05 and tri_counter[tuple(sorted([p[0], p[1], item]))] >= 10
], key = lambda x: -x[2])
[(('bottled_beer', 'red/blush_wine'), 'liquor', 0.22384332),
 (('white_bread', 'ham'), 'processed_cheese', 0.20990367),
 (('bottled_beer', 'liquor'), 'red/blush_wine', 0.16274776),
 (('sugar', 'baking_powder'), 'flour', 0.11954326),
 (('curd', 'sugar'), 'flour', 0.076272935),
 (('margarine', 'sugar'), 'flour', 0.07422828),
 (('flour', 'sugar'), 'baking_powder', 0.07345509),
 (('sugar', 'whipped/sour_cream'), 'flour', 0.072731614),
 (('rolls/buns', 'hamburger_meat'), 'Instant_food_products', 0.06818052),
 (('sugar', 'root_vegetables'), 'flour', 0.0641469),
 (('tropical_fruit', 'white_bread'), 'processed_cheese', 0.061861355),
 (('soda', 'ham'), 'processed_cheese', 0.06138085),
 (('white_bread', 'processed_cheese'), 'ham', 0.061199907),
 (('whole_milk', 'ham'), 'processed_cheese', 0.059773713),
 (('beef', 'root_vegetables'), 'herbs', 0.059243686),
 (('sugar', 'whipped/sour_cream'), 'baking_powder', 0.05871357),
 (('soda', 'salty_snack'), 'popcorn', 0.05830234),
 (('soda', 'popcorn'), 'salty_snack', 0.05819882),
 (('red/blush_wine', 'liquor'), 'bottled_beer', 0.057226427),
 (('flour', 'baking_powder'), 'sugar', 0.05517209),
 (('soda', 'hamburger_meat'), 'Instant_food_products', 0.054281656),
 (('processed_cheese', 'ham'), 'white_bread', 0.053193364),
 (('other_vegetables', 'ham'), 'processed_cheese', 0.052585844)]

R で抽出したアソシエーションルールと同じ様な結果が出ており、それなりの結果が出ているように思います。

skip-gram の場合

gensim の Word2Vec はデフォルトで CBoW を使うようですので、skip-gram の場合にどうなるかも簡単に確認してみました。

skip-gram の使用
model = word2vec.Word2Vec(sentences, iter = 500, min_count = 1, sg = 1)

まずは predict_output_word の結果をいくつか見てみます。

先頭(確率が最大のもの)は変わらないようですが、CBoW よりも確率の値が全体的に低くなっているようです。

bottled_beer と red/blush_wine
model.predict_output_word(['bottled_beer', 'red/blush_wine'])
[('liquor', 0.076620705),
 ('prosecco', 0.030791236),
 ('liquor_(appetizer)', 0.027123762),
 ・・・]
hamburger_meat と soda
model.predict_output_word(['hamburger_meat', 'soda'])
[('Instant_food_products', 0.022627866),
 ('pasta', 0.018009944),
 ('canned_vegetables', 0.01685342),
 ・・・]
root_vegetables と other_vegetables と whole_milk と yogurt
model.predict_output_word(['root_vegetables', 'other_vegetables', 'whole_milk', 'yogurt'])
[('herbs', 0.015105391),
 ('turkey', 0.014365919),
 ('rice', 0.01316431),
 ・・・]

ここでは、CBoW で 10番以内に入っていなかった rice が入っています。

次に、先程と同様に predict_output_word で 3品の組み合わせを確率順に抽出してみます。

確率の値が全体的に下がっているため、最小値の条件を 0.02 へ変えています。

predict_output_word を使った 3品の組み合わせ抽出
・・・
sorted([
    (p, item, prob) for p in pairs for item, prob in model.predict_output_word(p)
    if prob >= 0.02 and tri_counter[tuple(sorted([p[0], p[1], item]))] >= 10
], key = lambda x: -x[2])
[(('bottled_beer', 'red/blush_wine'), 'liquor', 0.076620705),
 (('bottled_beer', 'liquor'), 'red/blush_wine', 0.0712179),
 (('white_bread', 'ham'), 'processed_cheese', 0.039820198),
 (('red/blush_wine', 'liquor'), 'bottled_beer', 0.031292748),
 (('sugar', 'baking_powder'), 'flour', 0.030803043),
 (('sugar', 'whipped/sour_cream'), 'flour', 0.029322423),
 (('margarine', 'sugar'), 'flour', 0.027827),
 (('beef', 'root_vegetables'), 'herbs', 0.02740662),
 (('curd', 'sugar'), 'flour', 0.025570681),
 (('flour', 'sugar'), 'baking_powder', 0.025403246),
 (('tropical_fruit', 'root_vegetables'), 'turkey', 0.025329975),
 (('whole_milk', 'ham'), 'processed_cheese', 0.024535457),
 (('rolls/buns', 'hamburger_meat'), 'Instant_food_products', 0.02427808),
 (('flour', 'baking_powder'), 'sugar', 0.023779714),
 (('tropical_fruit', 'white_bread'), 'processed_cheese', 0.023528077),
 (('sugar', 'root_vegetables'), 'flour', 0.023394365),
 (('soda', 'salty_snack'), 'popcorn', 0.02322538),
 (('whole_milk', 'sugar'), 'flour', 0.023202542),
 (('fruit/vegetable_juice', 'ham'), 'processed_cheese', 0.023127634),
 (('butter', 'root_vegetables'), 'herbs', 0.02304014),
 (('soda', 'ham'), 'processed_cheese', 0.022633638),
 (('soda', 'hamburger_meat'), 'Instant_food_products', 0.022627866),
 (('citrus_fruit', 'sugar'), 'flour', 0.022040429),
 (('bottled_beer', 'soda'), 'liquor', 0.02189085),
 (('processed_cheese', 'ham'), 'white_bread', 0.021692872),
 (('yogurt', 'sugar'), 'flour', 0.021522585),
 (('tropical_fruit', 'other_vegetables'), 'turkey', 0.021456005),
 (('other_vegetables', 'beef'), 'herbs', 0.021407435),
 (('white_bread', 'processed_cheese'), 'ham', 0.021362728),
 (('curd', 'root_vegetables'), 'herbs', 0.021005861),
 (('other_vegetables', 'ham'), 'processed_cheese', 0.020965746),
 (('root_vegetables', 'whipped/sour_cream'), 'herbs', 0.020788824),
 (('other_vegetables', 'root_vegetables'), 'herbs', 0.020782541),
 (('sugar', 'whipped/sour_cream'), 'baking_powder', 0.02058014),
 (('whole_milk', 'sugar'), 'rice', 0.020371588),
 (('root_vegetables', 'frozen_vegetables'), 'herbs', 0.02027719),
 (('whole_milk', 'Instant_food_products'), 'hamburger_meat', 0.020258738),
 (('citrus_fruit', 'root_vegetables'), 'herbs', 0.020241175)]

最小値の条件を下げたために、より多くの組み合わせを抽出していますが、CBoW の結果と大きな違いは無さそうです。