CNN で輪郭の検出

画像内の物体の輪郭検出を CNN(畳み込みニューラルネット)で試してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190114/

概要

今回は、画像をピクセル単位で輪郭か否かに分類する事(輪郭 = 1, 輪郭以外 = 0)で輪郭を検出できないか試しました。

そこで、教師データとして以下のような衣服単体の画像(jpg)と衣服の輪郭部分だけを白く塗りつぶした画像(png)を用意しました。

f:id:fits:20190114225837j:plain

教師データを大量に用意するのは困難だったため、240x288 の画像 160 ファイルで学習を行っています。

学習

学習の処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、入力画像(jpg)を読み込みます。(教師データの画像は img ディレクトリへ配置しています)

import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

files = glob.glob('img/*.jpg')

imgs = np.array([img_to_array(load_img(f)) for f in files])

imgs.shape

入力データの形状は以下の通りです。

imgs.shape 結果
(160, 288, 240, 3)

(2) ラベルデータの準備

輪郭画像(png)を読み込み、128 を境にして二値化(輪郭 = 1、輪郭以外 = 0)します。

import os

th = 128

labels = np.array([img_to_array(load_img(f"{os.path.splitext(f)[0]}.png", color_mode = 'grayscale')) for f in files])

labels[labels < th] = 0
labels[labels >= th] = 1

labels.shape

ラベルデータの形状は以下の通りです。

labels.shape 結果
(160, 288, 240, 1)

(3) CNN モデル

どのようなネットワーク構成が適しているのか分からなかったので、セマンティックセグメンテーション等で用いられている Encoder-Decoder の構成を参考にしてみました。

30x36 まで段階的に縮小して(Encoder)、元の大きさ 240x288 まで段階的に拡大する(Decoder)ようにしています。

最終層の活性化関数に sigmoid を使って 0 ~ 1 の値となるようにしています。

損失関数は binary_crossentropy を使うと進捗が遅そうに見えたので ※、代わりに mean_squared_error を使っています。

 ※ 今回の場合、輪郭(= 1)に該当するピクセルの方が少なくなるため
    binary_crossentropy を使用する場合は fit の class_weight 引数で
    調整する必要があったと思われます

また、参考のため mean_absolute_error(mae) の値も出力するように metrics で指定しています。

from keras.models import Model
from keras.layers import Input, Dropout
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

# Encoder

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

# Decoder

x = Conv2DTranspose(64, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(32, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(16, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

output = Conv2D(1, 1, activation = 'sigmoid')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])

model.summary()
model.summary() 結果
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 288, 240, 3)       0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 288, 240, 3)       12        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 288, 240, 16)      448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 144, 120, 16)      0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 144, 120, 16)      64        
_________________________________________________________________
dropout_1 (Dropout)          (None, 144, 120, 16)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 144, 120, 32)      4640      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 72, 60, 32)        0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 72, 60, 32)        128       
_________________________________________________________________
dropout_2 (Dropout)          (None, 72, 60, 32)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 72, 60, 64)        18496     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 36, 30, 64)        0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 36, 30, 64)        256       
_________________________________________________________________
dropout_3 (Dropout)          (None, 36, 30, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 36, 30, 128)       73856     
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 36, 30, 128)       147584    
_________________________________________________________________
batch_normalization_5 (Batch (None, 36, 30, 128)       512       
_________________________________________________________________
dropout_4 (Dropout)          (None, 36, 30, 128)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 72, 60, 64)        73792     
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
batch_normalization_6 (Batch (None, 72, 60, 64)        256       
_________________________________________________________________
dropout_5 (Dropout)          (None, 72, 60, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 144, 120, 32)      18464     
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
batch_normalization_7 (Batch (None, 144, 120, 32)      128       
_________________________________________________________________
dropout_6 (Dropout)          (None, 144, 120, 32)      0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 288, 240, 16)      4624      
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 288, 240, 16)      2320      
_________________________________________________________________
batch_normalization_8 (Batch (None, 288, 240, 16)      64        
_________________________________________________________________
dropout_7 (Dropout)          (None, 288, 240, 16)      0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 288, 240, 1)       17        
=================================================================
Total params: 576,541
Trainable params: 575,831
Non-trainable params: 710
_________________________________________________________________

(4) 学習

教師データが少ないため、全て学習で使う事にします。

実行例(441 ~ 480 エポック)
hist = model.fit(imgs, labels, initial_epoch = 440, epochs = 480, batch_size = 10)

Keras では fit を繰り返し呼び出すと学習を(続きから)再開できるので、40 エポックを何回か繰り返しました。(バッチサイズは 20 で始めて途中で 10 へ変えたりしています)

その場合、正しいエポックを出力するには initial_epochepochs の値を調整する必要があります ※。

 ※ initial_epoch を指定しなくても
    fit を繰り返し実行するだけで学習は継続されますが、
    その場合は出力されるエポックの値がクリアされます(1 からのカウントとなる)
結果例
Epoch 441/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
Epoch 442/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0125
Epoch 443/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
・・・
Epoch 478/480
160/160 [=====・・・ - loss: 0.0045 - mean_absolute_error: 0.0116
Epoch 479/480
160/160 [=====・・・ - loss: 0.0046 - mean_absolute_error: 0.0117
Epoch 480/480
160/160 [=====・・・ - loss: 0.0044 - mean_absolute_error: 0.0115

(5) 確認

fit の戻り値から mean_squared_error(loss)と mean_absolute_error の値の遷移をグラフ化してみます。

%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (8, 4)

plt.subplot(1, 2, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 2, 2)
plt.plot(hist.history['mean_absolute_error'])
結果例(441 ~ 480 エポック)

f:id:fits:20190114231037p:plain

(6) 検証

(a) 教師データ

教師データの入力画像と輪郭画像(ラベルデータ)、model.predict の結果を並べて表示してみます。

def predict(index, s = 6.0):
    plt.rcParams['figure.figsize'] = (s, s)

    sh = imgs.shape[1:-1]

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([imgs[index]]))[0]
    pred *= 255

    plt.subplot(1, 3, 1)
    # 入力画像の表示
    plt.imshow(imgs[index].astype(int))

    plt.subplot(1, 3, 2)
    # 輪郭画像(ラベルデータ)の表示
    plt.imshow(labels[index].reshape(sh), cmap = 'gray')

    plt.subplot(1, 3, 3)
    # predict の結果表示
    plt.imshow(pred.reshape(sh).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, 輪郭画像, model.predict 結果

f:id:fits:20190114231107j:plain

概ね教師データに近い結果が出るようになっています。

(b) 教師データ以外

教師データとして使っていない画像に対して model.predict を実施し、輪郭の検出を行ってみます。

def predict_eval(file, s = 4.0):
    plt.rcParams['figure.figsize'] = (s, s)

    img = img_to_array(load_img(file))

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    plt.subplot(1, 2, 1)
    # 入力画像の表示
    plt.imshow(img.astype(int))

    plt.subplot(1, 2, 2)
    # predict の結果表示
    plt.imshow(pred.reshape(pred.shape[:-1]).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, model.predict 結果

f:id:fits:20190114231154j:plain

所々で途切れたりしていますが、ある程度の輪郭は検出できているように見えます。

(7) 保存

学習したモデルを保存します。

model.save('model/c1_480.h5')

輪郭検出

学習済みモデルを使って輪郭検出を行う処理をスクリプト化してみました。

predict_contours.py
import sys
import os
import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
from keras.models import load_model
import cv2

model_file = sys.argv[1]
img_files = sys.argv[2]
dest_dir = sys.argv[3]

model = load_model(model_file)

for f in glob.glob(img_files):
    img = img_to_array(load_img(f))

    # 輪郭の検出
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    file, ext = os.path.splitext(os.path.basename(f))

    # 画像の保存
    cv2.imwrite(f"{dest_dir}/{file}_predict.png", pred)

    print(f"done: {f}")
実行例
python predict_contours.py model/c1_480.h5 img_eval2/*.jpg result

480 エポックの学習モデルを使って、教師データに無いタイプの背景を使った画像(影の影響もある)に試してみました。

輪郭検出結果例
入力画像 処理結果
f:id:fits:20181229154654j:plain f:id:fits:20190114231530p:plain
f:id:fits:20181229155049j:plain f:id:fits:20190114231547p:plain

こちらは難しかったようです。

なお、2つ目の画像は 120 エポックの学習モデルの方が良好な結果(輪郭がより多く検出されていた)でした。