IDWR データで再帰型ニューラルネットワーク - Keras

前回 加工した IDWR データ を使って再帰ニューラルネットワーク(RNN)を Keras + TensorFlow で試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180121/

はじめに

インストール

tensorflow と keras をインストールしておきます。

インストール例
> pip install tensorflow
> pip install keras

データセット

データセット前回 に作成した、2014年 1週目 ~ 2017年 49週目 (2015年は 53週目まで) の teiten.csv を 1つの csv ファイル (idwr.csv) へ連結したものを使います。

再帰ニューラルネットワーク(RNN)

指定の感染症に対する週別の全国合計値を再帰ニューラルネットワーク(RNN)で学習・予測する処理を実装してみます。

(a) 4週分のデータを使用

とりあえず、任意の週の過去 4週の報告数を入力データ、その週の報告数をラベルデータとして学習を実施してみます。

ただ、入力データに季節性(周期性)の指標となるデータ(ここでは週)を含んでいないので、季節性のあるデータに対する結果は期待できないように思います。

データ加工

入力データとラベルデータは、以下の要領で 1週分ずつスライドしたものを用意します。

  • 2014年 1 ~ 4週目を入力データとし 5週目をラベルデータ
  • 2014年 2 ~ 5週目を入力データとし 6週目をラベルデータ
入力データ・ラベルデータの作成例(インフルエンザ)
import pandas as pd
import numpy as np

t = 4

df = pd.read_csv('idwr.csv', encoding = 'UTF-8')

ds = df.groupby(['year', 'week'])['インフルエンザ'].sum()

def window(d, wsize):
    return [d[i:(i + wsize)].values.flatten() for i in range(len(d) - wsize + 1)]

# t + 1 週分のスライドデータを作成
dw = window(ds.astype('float'), t + 1)

# 入力データ(t週分)
data = np.array([i[0:-1] for i in dw]).reshape(len(dw), t, 1)
# ラベルデータ(残り 1週分)
labels = np.array([i[-1] for i in dw]).reshape(len(dw), 1)

上記では 1週分ずつスライドさせた 5週分の配列を作成し、4週分を入力データ、残りの 1週分をラベルデータとしています。

その結果、data・labels 変数の内容は以下のようになります。

data 変数の内容
array([[[  9.89100000e+03],
        [  2.71000000e+04],
        [  5.82330000e+04],
        [  1.22618000e+05]],

       [[  2.71000000e+04],
        [  5.82330000e+04],
        [  1.22618000e+05],
        [  1.70403000e+05]],

       [[  5.82330000e+04],
        [  1.22618000e+05],
        [  1.70403000e+05],
        [  1.51829000e+05]],

       ・・・

       [[  2.40700000e+03],
        [  2.58800000e+03],
        [  3.79900000e+03],
        [  7.28000000e+03]],

       [[  2.58800000e+03],
        [  3.79900000e+03],
        [  7.28000000e+03],
        [  1.27850000e+04]]])
labels 変数の内容
array([[  1.70403000e+05],
       [  1.51829000e+05],
       [  1.39162000e+05],
       ・・・
       [  1.27850000e+04],
       [  2.01270000e+04]])

学習

上記のデータを使って再帰ニューラルネットワークの学習処理を実施してみます。

今回は GRU(ゲート付き再帰的ユニット)を使いました。

学習処理例(Jupyter Notebook 使用)
%matplotlib inline
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.recurrent import GRU
from keras.optimizers import Nadam

epoch = 3000
batch_size = 52
n_num = 80

model = Sequential()

# 再帰型ニューラルネットワークの GRU 使用
model.add(GRU(n_num, activation = 'relu', input_shape = (t, 1)))

model.add(Dense(1))
model.add(Activation('linear'))

opt = Nadam()

# モデルの構築
model.compile(loss = 'mean_squared_error', optimizer = opt)

# 学習
hist = model.fit(data, labels, epochs = epoch, batch_size = batch_size)

# 誤差の遷移状況をグラフ化
plt.plot(hist.history['loss'])

処理結果は以下の通り。 誤差の遷移を見る限り、学習は進んでいるように見えます。

学習処理結果例
Epoch 1/3000
202/202 [==============================] - 1s 7ms/step - loss: 3604321859.1683
Epoch 2/3000
202/202 [==============================] - 0s 193us/step - loss: 3120768191.3663
・・・
Epoch 2999/3000
202/202 [==============================] - 0s 183us/step - loss: 41551096.5149
Epoch 3000/3000
202/202 [==============================] - 0s 193us/step - loss: 41367050.6931

f:id:fits:20180121173527p:plain

予測

3つのグラフを描画して学習したモデルの予測能力を比べてみます。

名称 概要
actual 実データ
predict1 実データを入力データにした予測結果(学習時の入力データをそのまま使用)
predict2 実データの先頭のみを使った予測結果

predict2 は以下のように最初だけ実データを使って、予測結果を次の入力データの一部にして予測を繋げていった結果です。

  • 2014年 1 ~ 4週目を入力データとし 5週目を予測
  • 2014年 2 ~ 4週目と 5週目の予測値を入力データとし 6週目を予測
  • 2014年 3 ~ 4週目と 5 ~ 6週目の予測値を入力データとし 7週目を予測
  • 2014年 4週目と 5 ~ 7週目の予測値を入力データとし 8週目を予測
  • 5 ~ 8週目の予測値を入力データとし 9週目を予測

そのため、実際の予測では predict2 の結果が重要となります。

予測処理例(Jupyter Notebook 使用)
from functools import reduce

# 実データの描画
plt.plot(ds.values, label = 'actual')

# 入力データを使った予測(predict1)
res1 = model.predict(data)
# predict1 の結果を描画
plt.plot(range(t, len(res1) + t), res1, label = 'predict1')

# predict2 用の予測処理
def predict(a, b):
    r = model.predict(a[1])

    return (
        np.append(a[0], r), 
        np.append(a[1][:, 1:], np.array([r]), axis = 1)
    )

# 入力データの先頭
fst_data = data[0].reshape(1, t, 1)
# 入力データの先頭を使った予測(predict2)
res2, _ = reduce(predict, range(len(ds) - t), (np.array([]), fst_data))
# predict2 の結果を描画
plt.plot(range(t, len(res2) + t), res2, label = 'predict2')

plt.legend()

インフルエンザのデータに対して適用した結果が以下です。

予測結果(インフルエンザ)

f:id:fits:20180121173550p:plain

predict1 は実データをほぼ再現できていますが、predict2 は全く駄目でした。

やはり、入力データへ週の情報を与えなかった事が原因だと思われます。

備考

上記処理を単一スクリプト化したものが以下です。

sample1.py
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from functools import reduce
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.recurrent import GRU
from keras.optimizers import Nadam

data_file = sys.argv[1]
item_name = sys.argv[2]
dest_file = sys.argv[3]

t = 4
epoch = 3000
batch_size = 52
n_num = 80

df = pd.read_csv(data_file, encoding = 'UTF-8')

ds = df.groupby(['year', 'week'])[item_name].sum()

def window(d, wsize):
    return [d[i:(i + wsize)].values.flatten() for i in range(len(d) - wsize + 1)]

dw = window(ds.astype('float'), t + 1)

data = np.array([i[0:-1] for i in dw]).reshape(len(dw), t, 1)
labels = np.array([i[-1] for i in dw]).reshape(len(dw), 1)

model = Sequential()

model.add(GRU(n_num, activation = 'relu', input_shape = (t, 1)))

model.add(Dense(1))
model.add(Activation('linear'))

opt = Nadam()

model.compile(loss = 'mean_squared_error', optimizer = opt)

hist = model.fit(data, labels, epochs = epoch, batch_size = batch_size)

fig, axes = plt.subplots(1, 2, figsize = (12, 6))

axes[0].plot(hist.history['loss'])

axes[1].plot(ds.values, label = 'actual')

res1 = model.predict(data)

axes[1].plot(range(t, len(res1) + t), res1, label = 'predict1')

def predict(a, b):
    r = model.predict(a[1])

    return (
        np.append(a[0], r), 
        np.append(a[1][:, 1:], np.array([r]), axis = 1)
    )

fst_data = data[0].reshape(1, t, 1)
res2, _ = reduce(predict, range(len(ds) - t), (np.array([]), fst_data))

axes[1].plot(range(t, len(res2) + t), res2, label = 'predict2')

axes[1].legend()

plt.savefig(dest_file)

(b) 年と週を追加したデータを使用

季節性の強いインフルエンザのようなデータに対して (a) の入力データでは無理があった気がするので、次は入力データへ年と週を追加して試してみます。

入力データとラベルデータ

(a) の入力データへラベルデータの年と週を追加します。

  • 2014 と 5、2014年 1 ~ 4週目を入力データとし 5週目をラベルデータ
  • 2014 と 6、2014年 2 ~ 5週目を入力データとし 6週目をラベルデータ
入力データ(data 変数の内容)例
array([[[  2.01400000e+03],
        [  5.00000000e+00],
        [  9.89100000e+03],
        [  2.71000000e+04],
        [  5.82330000e+04],
        [  1.22618000e+05]],

       [[  2.01400000e+03],
        [  6.00000000e+00],
        [  2.71000000e+04],
        [  5.82330000e+04],
        [  1.22618000e+05],
        [  1.70403000e+05]],

       ・・・

       [[  2.01700000e+03],
        [  4.80000000e+01],
        [  2.40700000e+03],
        [  2.58800000e+03],
        [  3.79900000e+03],
        [  7.28000000e+03]],

       [[  2.01700000e+03],
        [  4.90000000e+01],
        [  2.58800000e+03],
        [  3.79900000e+03],
        [  7.28000000e+03],
        [  1.27850000e+04]]])
ラベルデータ(labels 変数の内容)例
array([[  1.70403000e+05],
       [  1.51829000e+05],
       ・・・
       [  1.27850000e+04],
       [  2.01270000e+04]])

学習と予測処理

発症報告数と週の値では桁数が大きく異なるケースがあるので、このままで効果的な学習ができるかは微妙です。

そこで、ここではバッチ正規化(BatchNormalization)の層を先頭に追加する事で対処してみました。

バッチ正規化によりミニバッチ単位でデータの正規化が行われるようになります。 この場合、ミニバッチのサイズが重要になってくると思います。

また、IDWR では ISO 週番号を使っているようなので、年によって 52週の場合と 53週の場合があります。

ここでは isocalendar でその年の最終週を取得する事で対応しました。(12/28 を含む週がその年の最終週)

sample2.py
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime
from functools import reduce
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.recurrent import GRU
from keras.optimizers import Nadam

data_file = sys.argv[1]
item_name = sys.argv[2]
dest_file = sys.argv[3]
predict_size = int(sys.argv[4])
batch_size = int(sys.argv[5])

t = 4
epoch = 5000
n_num = 80

df = pd.read_csv(data_file, encoding = 'UTF-8')

ds = df.groupby(['year', 'week'])[item_name].sum()

def window_with_index(d, wsize):
    return [
        np.r_[
            d.index[i + wsize - 1][0], # 年
            d.index[i + wsize - 1][1], # 週
            d[i:(i + wsize)].values.flatten()
        ] for i in range(len(d) - wsize + 1)
    ]

dw = window_with_index(ds.astype('float'), t + 1)

# 入力データの要素数
input_num = len(dw[0]) - 1

data = np.array([i[0:-1] for i in dw]).reshape(len(dw), input_num, 1)
labels = np.array([i[-1] for i in dw]).reshape(len(dw), 1)

model = Sequential()

# バッチ正規化
model.add(BatchNormalization(axis = 1, input_shape = (input_num, 1)))

model.add(GRU(n_num, activation = 'relu'))

model.add(Dense(1))
model.add(Activation('linear'))

opt = Nadam()

model.compile(loss = 'mean_squared_error', optimizer = opt)

hist = model.fit(data, labels, epochs = epoch, batch_size = batch_size)

fig, axes = plt.subplots(1, 2, figsize = (12, 6))

axes[0].plot(hist.history['loss'])

# 実データの描画
axes[1].plot(ds.values, label = 'actual')

# predict1(入力データをそのまま使用)
res1 = model.predict(data)

axes[1].plot(range(t, len(res1) + t), res1, label = 'predict1')

# 指定年の最終週(52 か 53)の取得
def weeks_of_year(year):
    return datetime(year, 12, 28).isocalendar()[1]

# predict2 用の予測処理
def predict(a, b):
    r = model.predict(a[1])

    year = a[1][0, 0, 0]
    week = a[1][0, 1, 0] + 1

    if week > weeks_of_year(int(year)):
        year += 1
        week = 1

    next = np.r_[
        year,
        week,
        a[1][:, 3:].flatten(),
        r.flatten()
    ].reshape(a[1].shape)

    return (np.append(a[0], r), next)

fst_data = data[0].reshape(1, input_num, 1)
# predict2(入力データの先頭のみ使用)
res2, _ = reduce(predict, range(predict_size), (np.array([]), fst_data))

axes[1].plot(range(t, predict_size + t), res2, label = 'predict2')

min_year = min(ds.index.levels[0])
years = range(min_year, min_year + int(predict_size / 52) + 1)

# X軸のラベル設定
axes[1].set_xticklabels(years)
# X軸の目盛設定
axes[1].set_xticks(
    reduce(lambda a, b: a + [a[-1] + weeks_of_year(b)], years[:-1], [0])
)

# 凡例をグラフ外(下部)へ横並びで出力
axes[1].legend(bbox_to_anchor = (1, -0.1), ncol = 3)

fig.subplots_adjust(bottom = 0.15)

plt.savefig(dest_file)

実行

インフルエンザに対して 2014年 5週目から 250週分をバッチサイズ 52 で予測(predict2)するように実行してみます。

実行例(インフルエンザ)
> python sample2.py idwr.csv インフルエンザ sample2.png 250 52

・・・
Epoch 4998/5000
202/202 [==============================] - 0s 155us/step - loss: 133048118.4158
Epoch 4999/5000
202/202 [==============================] - 0s 155us/step - loss: 154247298.2970
Epoch 5000/5000
202/202 [==============================] - 0s 232us/step - loss: 171130778.4554
1. インフルエンザ (250週予測, batch_size = 52)

f:id:fits:20180121173637p:plain

predict2 は劇的に改善され、将来の予測値もそれっぽいものになりました。

他の感染症に対しても同じように処理してみると以下のようになりました。

2. 感染性胃腸炎 (250週予測, batch_size = 52)

f:id:fits:20180121173656p:plain

3. 手足口病1 (250週予測, batch_size = 52)

f:id:fits:20180121173714p:plain

2年周期を学習できていないような結果となりました。

そこで、バッチサイズを 104 にして 320週分を予測してみたところ以下のようになりました。

4. 手足口病2 (320週予測, batch_size = 104)

f:id:fits:20180121173731p:plain

5. 水痘1 (250週予測, batch_size = 52)

f:id:fits:20180121173748p:plain

predict2 の差異が目立っています。

そこで、バッチサイズを 26 にしてみたところ改善が見られました。 ちなみに、バッチサイズを大きくしても特に改善は見られませんでした。

6. 水痘2 (250週予測, batch_size = 26)

f:id:fits:20180121173805p:plain

7. 流行性耳下腺炎 (250週予測, batch_size = 52)

f:id:fits:20180121173828p:plain

IDWR データの入手と加工

時系列データ分析を試すのに適した季節性(周期性)を持つオープンデータを探していて以下を見つけました。

インフルエンザ等の感染症の週単位の報告数が都道府県別にまとまっており、csv ファイルでデータを入手できるようになっています。

今回は、上記データを分析するための前処理(入手と加工)を行います。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180114/

(1) データ入手

csv ファイルは IDWR速報データ から入手できます。

「定点把握疾患(週報告)、報告数、定点当たり報告数、都道府県別」の場合、csv ファイル名は <年4桁>-<週2桁>-teiten.csv となっています。

週は 1 ~ 52 もしくは 53 となっており、年によって 52 の場合と 53 の場合があります。

URL のルールが決まっているので過去データも簡単にダウンロードできましたが、curl コマンド等を使ってダウンロードする際は HTTP リクエストヘッダーへ User-Agent を付ける必要がありました。

teiten.csv 内容

teiten.csv の内容は以下の通りで、文字コードShift_JIS となっています。

これをそのままデータ分析で使うには以下の点が気になります。

  • ヘッダー行が複数行
  • 報告がない場合は - となっている
  • 最終行が空データ
2017-44-teiten.csv
"報告数・定点当り報告数、疾病・都道府県別","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","",""
"2017年44週(10月30日~11月05日)","2017年11月08日作成","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","",""
"","インフルエンザ","","RSウイルス感染症","","咽頭結膜熱","","A群溶血性レンサ球菌咽頭炎","","感染性胃腸炎","","水痘","","手足口病","","伝染性紅斑","","突発性発しん","","百日咳","","ヘルパンギーナ","","流行性耳下腺炎","","急性出血性結膜炎","","流行性角結膜炎","","細菌性髄膜炎","","無菌性髄膜炎","","マイコプラズマ肺炎","","クラミジア肺炎","","感染性胃腸炎(ロタウイルス)",""
"","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当","報告","定当"
"総数","2407","0.49","3033","","1621","0.51","5940","1.88","10937","3.47","1469","0.47","5126","1.62","173","0.05","1259","0.40","50","0.02","967","0.31","899","0.28","8","0.01","484","0.70","9","0.02","11","0.02","185","0.39","5","0.01","4","0.01"
"北海道","137","0.62","120","","369","2.65","384","2.76","217","1.56","71","0.51","103","0.74","2","0.01","33","0.24","5","0.04","9","0.06","11","0.08","-","-","9","0.31","-","-","1","0.04","9","0.39","-","-","1","0.04"
"青森県","10","0.15","35","","12","0.29","44","1.05","74","1.76","21","0.50","121","2.88","-","-","11","0.26","1","0.02","11","0.26","28","0.67","-","-","4","0.36","-","-","-","-","5","0.83","-","-","-","-"
・・・
"鹿児島県","79","0.87","37","","46","0.87","76","1.43","281","5.30","20","0.38","134","2.53","-","-","20","0.38","-","-","20","0.38","81","1.53","-","-","6","0.86","-","-","-","-","2","0.17","-","-","-","-"
"沖縄県","230","3.97","13","","16","0.47","25","0.74","71","2.09","7","0.21","42","1.24","-","-","10","0.29","3","0.09","7","0.21","3","0.09","-","-","13","1.44","-","-","-","-","-","-","1","0.14","-","-"
"","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","",""

f:id:fits:20180114215531p:plain

(2) 加工

複数の teiten.csv から必要な情報のみを抽出して単一の csv へ変換する処理を Python で実装してみました。

基本的な処理は pandas.read_csv で行っています。

pandas.read_csv の引数 内容
skiprows 不要なヘッダー行と総数の行を除外(総数は算出できるため)
skipfooter 最終行の空データを除外
usecols "定当" 列を除外
converters '-' を 0 へ変換

時系列データ分析用に年と週、そして週の最終日を追加しています。

idwr_convert.py
# coding: utf-8
import codecs
import functools
import glob
import os
import re
import sys
import pandas as pd

data_files = f"{sys.argv[1]}/*.csv"
dest_file = sys.argv[2]

r = re.compile('"([0-9]+)年([0-9]+)週\(.*[^0-9]([0-9]+)月([0-9]+)日\)')

# 抽出する列(定当データの除外)
cols = [i for i in range(38) if i == 0 or i % 2 == 1]

# 変換処理('-' を 0 へ変換)
conv = lambda x: 0 if x == '-' else int(x)

# 列毎の変換処理
conv_map = {i:conv for i in range(len(cols)) if i > 0}

# 年、週、その週の最終日を抽出
def read_info(file):
    f = codecs.open(file, 'r', 'Shift_JIS')
    f.readline()

    m = r.match(f.readline())

    f.close()

    year = int(m.group(1))
    week = int(m.group(2))
    # 50週を超えていて 1月なら次の年
    last_year = year + 1 if week > 50 and m.group(3) == '01' else year

    return (year, week, f"{last_year}-{m.group(3)}-{m.group(4)}")

def read_csv(file):
    d = pd.read_csv(file, encoding = 'Shift_JIS', skiprows = [0, 1, 3, 4],
                     usecols = cols, skipfooter = 1, converters = conv_map, 
                     engine = 'python')

    info = read_info(file)

    d['year'] = info[0] # 年
    d['week'] = info[1] # 週
    d['lastdate'] = info[2] # その週の最終日

    return d.rename(columns = {'Unnamed: 0': 'pref'})


dfs = [read_csv(f) for f in glob.glob(data_files)]

# データフレームの連結
df = functools.reduce(lambda a, b: a.append(b), dfs)

# csv ファイルとして出力
df.to_csv(dest_file, encoding = 'UTF-8')

実行

2014年 1週目 ~ 2017年 49週目 (2015年は 53週目まである) の teiten.csv を data ディレクトリを配置して、上記を実行しました。

実行例
> python idwr_convert.py data idwr.csv
idwr.csv
,pref,インフルエンザ,RSウイルス感染症,咽頭結膜熱,A群溶血性レンサ球菌咽頭炎,感染性胃腸炎,水痘,手足口病,伝染性紅斑,突発性発しん,百日咳,ヘルパンギーナ,流行性耳下腺炎,急性出血性結膜炎,流行性角結膜炎,細菌性髄膜炎,無菌性髄膜炎,マイコプラズマ肺炎,クラミジア肺炎,感染性胃腸炎(ロタウイルス),year,week,lastdate
0,北海道,333,125,40,118,203,236,2,2,12,0,5,4,0,1,0,0,0,0,1,2014,1,2014-01-05
1,青森県,77,22,18,9,134,71,0,10,4,0,0,7,0,0,0,0,0,0,2,2014,1,2014-01-05
2,岩手県,98,17,11,25,224,82,2,3,7,0,0,15,0,7,0,1,10,0,1,2014,1,2014-01-05
・・・
44,宮崎県,20,47,13,58,191,23,9,18,15,0,1,55,0,11,0,0,0,0,0,2015,53,2016-01-03
45,鹿児島県,40,27,50,109,303,40,4,32,21,0,0,63,0,7,0,0,0,0,0,2015,53,2016-01-03
46,沖縄県,353,3,5,30,209,30,1,2,7,3,1,57,0,2,1,1,9,0,0,2015,53,2016-01-03
0,北海道,1093,147,116,491,433,101,17,273,39,2,1,306,0,6,1,0,11,0,11,2016,1,2016-01-10
1,青森県,142,31,23,52,201,14,0,24,14,0,0,38,0,7,2,1,3,0,0,2016,1,2016-01-10
2,岩手県,156,33,4,125,255,14,2,15,18,1,2,29,0,13,1,0,14,0,0,2016,1,2016-01-10
・・・
44,宮崎県,347,46,81,98,358,14,40,0,29,2,4,24,0,15,0,1,0,0,0,2017,49,2017-12-10
45,鹿児島県,252,24,83,134,426,38,71,1,18,0,11,86,0,8,0,0,0,0,0,2017,49,2017-12-10
46,沖縄県,410,5,8,44,70,24,72,2,9,0,6,4,0,20,0,1,0,0,0,2017,49,2017-12-10

(3) グラフ化

idwr.csv感染症毎に集計して折れ線グラフ化してみます。

(a) matplotlib 使用 - 全体

まずは matplotlib を使って全ての感染症を同一グラフへ表示してみます。

idwr_plot_matplotlib.py
import sys
import pandas as pd
import matplotlib.pyplot as plt

data_file = sys.argv[1]
img_file = sys.argv[2]

df = pd.read_csv(data_file, parse_dates = ['lastdate'])

df.groupby('lastdate').sum().iloc[:, 1:20].plot(legend = False)

plt.savefig(img_file)

ここで、df.groupby('lastdate').sum() では lastdate 毎にグルーピングして合計しています。

f:id:fits:20180114215613p:plain

これだと不要な列も含んでしまうので df.groupby('lastdate').sum().iloc[:, 1:20] で必要な列のみを抽出します。

f:id:fits:20180114215636p:plain

実行結果は以下の通りです。

実行例
> python idwr_plot_matplotlib.py idwr.csv idwr_matplotlib.png
idwr_matplotlib.png

f:id:fits:20180114220012p:plain

他と比べ目立って報告数が多いのはインフルエンザです。

(b) matplotlib 使用 - 個別

次に指定したものだけをグラフ表示してみます。

idwr_plot_matplotlib2.py
import sys
import pandas as pd
import matplotlib.pyplot as plt

data_file = sys.argv[1]
item_name = sys.argv[2]
img_file = sys.argv[3]

df = pd.read_csv(data_file, parse_dates = ['lastdate'])

df.groupby('lastdate').sum()[item_name].plot(legend = False)

plt.savefig(img_file)
実行例
> python idwr_plot_matplotlib2.py idwr.csv インフルエンザ インフルエンザ.png

グラフ例

グラフ形状的に特徴のあるものをいくつかピックアップしてみました。

1. インフルエンザ

1年周期で一定期間内に大流行しています。

f:id:fits:20180114220104p:plain

2. 感染性胃腸炎

1年周期ですが、2016年の末頃は(異常に)大流行しています。

f:id:fits:20180114220118p:plain

3. 手足口病

2年周期で流行しています。

f:id:fits:20180114220133p:plain

4. 水痘

なんとなく 1年周期がありそうですが、全体的な傾向は下がっているように見えます。

f:id:fits:20180114220149p:plain

5. 流行性耳下腺炎

目立った周期性は無さそうです。

f:id:fits:20180114220202p:plain

このように、なかなか興味深いデータが揃っているように思います。

(c) HoloViews + bokeh 使用

最後に、HoloViews + bokeh でインタラクティブに操作できるグラフを作成します。

とりあえず、感染症(列)毎に折れ線グラフ(Curve)を作って Overlay で重ねてみました。

opts でグラフのサイズやフォントサイズを変更しています。

なお、Curve で label = c とすると (label の値は凡例に使われる)、正常に動作しなかったため、回避措置として label = f"'{c}'" のようにしています。

idwr_plot_holoviews.py
import sys
import pandas as pd
import holoviews as hv

hv.extension('bokeh')

data_file = sys.argv[1]
dest_file = sys.argv[2]

df = pd.read_csv(data_file, parse_dates = ['lastdate'])

dg = df.groupby('lastdate').sum().iloc[:, 1:20]

# 感染症毎に Curve を作成
plist = [hv.Curve(dg[c].reset_index().values, label = f"'{c}'") for c in dg]

# 複数の Curve を重ねる
p = hv.Overlay(plist)

# グラフのサイズを変更
p = p.opts(plot = dict(width = 800, height = 600, fontsize = 8))
# X軸・Y軸のラベルを変更
p = p.redim.label(x = 'lastdate', y = 'num')

# グラフの保存
hv.renderer('bokeh').save(p, dest_file)

実行結果は以下の通り。 拡張子 .html は勝手に付与されるようです。

実行例
> python idwr_plot_holoviews.py idwr.csv idwr_holoviews
idwr_holoviews.html 表示例

f:id:fits:20180114220302p:plain

Python でアソシエーション分析 - Orange3-Associate

前回 と同様のアソシエーション分析を PythonOrange で試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180109/

はじめに

データセット

前回 と同じデータファイルを使います。

data.basket
C,S,M,R
T,Y,C
P,Y,C,M
O,W,L
R
O,U,R,L
P
W,C
T,O,W,B,C
C,T,W,B,Y,F,D
・・・
R,P,S
B,
B,S
B,F
C,F,N

インストール

今回は Orange3 を使います。

Orange3 ではアソシエーション分析に関する処理を Orange3-Associate へ分離しているようなので、これらをインストールしておきます。

Orange3 インストール

Orange3 自体は conda コマンドでインストールできるようです。

Orange3 インストール例
> conda install orange3

より新しいバージョンをインストールするには conda-forge を使えば良さそうです。

Orange3 インストール例(conda-forge 利用)
> conda config --add channels conda-forge
> conda install orange3

Orange3-Associate インストール

試した時点では、Orange3-Associate を conda や pip でインストールできなかったので、ソースを取得してインストールしました。

Orange3-Associate インストール例
> git clone https://github.com/biolab/orange3-associate.git
> cd orange3-associate
> python setup.py install

実装と実行

前回 と同様の処理を実装してみます。

(a) リフト値なし

まずは、Orange.data.Table でデータファイルを読み込みます。 (ファイル名の拡張子は .basket とする必要がありそう)

その結果、tbl 変数の内容は [[C=1.000, S=1.000, M=1.000, R=1.000], [C=1.000, T=1.000, Y=1.000], ・・・] のようになります。

C や S のような文字列では処理できないようなので、OneHot.encode で One hot 表現化します。

その結果、X 変数の内容は [[0, 1, 2, 3], [0, 4, 5], ・・・] のようになります。(出現した順に 0 からの連番が割り当てられるようです)

frequent_itemsets で組み合わせ毎の発生件数をカウントします。 第 2引数の min_support で抽出する support(支持度)の最小値を指定できます。以下のサンプルでは 5 と指定しているので 5件以上のものが抽出されます。(比率の指定も可能)

itemsets 変数の内容は {frozenset({11}): 42, frozenset({0}): 41, frozenset({0, 11}): 12, ・・・} のようになります。

association_rules でアソシエーションルールの抽出を行います。 第 2引数の min_confidence で抽出する confidence(確信度)の最小値を指定できます。

ただし、association_rules ではリフト値を取得できないようです。

PQ 変数の内容は frozenset({11, 7}) のようになるので、元の文字列へ戻すために OneHot.decode で処理します。

OneHot.decode の結果は [(11, ContinuousVariable(name='B', number_of_decimals=3), 0), (7, ContinuousVariable(name='O', number_of_decimals=3), 0)] のようになるので ContinuousVariable の name の値を取り出しています。

sample.py
import sys
import Orange
from orangecontrib.associate.fpgrowth import *

data_file = sys.argv[1]

# データファイル読み込み
tbl = Orange.data.Table(data_file)

X, mapping = OneHot.encode(tbl)

itemsets = dict(frequent_itemsets(X, 5))

# アソシエーションルールの抽出
rules = association_rules(itemsets, 0.7)

def decode_onehot(d):
    items = OneHot.decode(d, tbl, mapping)
    # ContinuousVariable の name 値を取得
    return list(map(lambda v: v[1].name, items))

for P, Q, support, confidence in rules:
    lhs = decode_onehot(P)
    rhs = decode_onehot(Q)

    print(f"lhs = {lhs}, rhs = {rhs}, support = {support}, confidence = {confidence}")

実行結果は以下の通り。

実行結果
> python sample.py data.basket

lhs = ['B', 'O'], rhs = ['W'], support = 5, confidence = 0.8333333333333334
lhs = ['B', 'T'], rhs = ['C'], support = 5, confidence = 1.0
lhs = ['N'], rhs = ['C'], support = 10, confidence = 0.7142857142857143
lhs = ['T'], rhs = ['C'], support = 8, confidence = 0.8

(b) リフト値あり

リフト値の取得には rules_stats を使います。

rules_stats の第 3引数にはデータセットの件数(今回は 100)を指定します。(この値はリフト値の算出に使われる)

sample2.py
import sys
import Orange
from orangecontrib.associate.fpgrowth import *

data_file = sys.argv[1]

tbl = Orange.data.Table(data_file)

X, mapping = OneHot.encode(tbl)

itemsets = dict(frequent_itemsets(X, 5))

# アソシエーションルールの抽出
rules = association_rules(itemsets, 0.7)

# リフト値を含んだ結果を取得
stats = rules_stats(rules, itemsets, len(X))

def decode_onehot(d):
    items = OneHot.decode(d, tbl, mapping)
    return list(map(lambda v: v[1].name, items))

# リフト値(7番目の要素)でソート
for s in sorted(stats, key = lambda x: x[6], reverse = True):

    lhs = decode_onehot(s[0])
    rhs = decode_onehot(s[1])

    support = s[2]
    confidence = s[3]
    lift = s[6]

    print(f"lhs = {lhs}, rhs = {rhs}, support = {support}, confidence = {confidence}, lift = {lift}")

実行結果は以下の通り。

R の arules を使った 前回 と概ね同じ結果になりました。

実行結果
> python sample2.py data.basket

lhs = ['B', 'O'], rhs = ['W'], support = 5, confidence = 0.8333333333333334, lift = 3.333333333333334
lhs = ['B', 'T'], rhs = ['C'], support = 5, confidence = 1.0, lift = 2.4390243902439024
lhs = ['T'], rhs = ['C'], support = 8, confidence = 0.8, lift = 1.951219512195122
lhs = ['N'], rhs = ['C'], support = 10, confidence = 0.7142857142857143, lift = 1.7421602787456447

R でアソシエーション分析 - arules

R言語arules を使ってアソシエーション分析を試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180108/

はじめに

データセット

今回は、適当に作った下記データセット (100行) を使います。

1行が 1つの取引で , で区切られたアルファベット群が同時に購入された商品とみなします。(例えば、T,Y,C なら T と Y と C 商品を同時購入)

data.basket
C,S,M,R
T,Y,C
P,Y,C,M
O,W,L
R
O,U,R,L
P
W,C
T,O,W,B,C
C,T,W,B,Y,F,D
・・・
R,P,S
B,
B,S
B,F
C,F,N

インストール

install.packages で arules をインストールしておきます。

インストール例
> install.packages("arules")

実装

arules の apriori 関数で Apriori アルゴリズムを使ったアソシエーションルールの抽出を行えます。

read.transactions でデータファイルを transactions オブジェクト化して、apriori の data 引数として使います。

parameter 引数を使って support(支持度)が 0.05 以上 ※、confidence(確信度)が 0.7 以上のものを抽出するようにしました。

 ※ 今回使用するデータセットの件数が 100件なので 0.05 は 5件になる
sample.R
library(arules)

args <- commandArgs(TRUE)

# データファイルを transactions オブジェクト化
tr <- read.transactions(args[1], format = "basket", sep = ",")

# アソシエーションルールの抽出
tr.ap <- apriori(tr, parameter = list(support = 0.05, confidence = 0.7))

# lift 値でソート
inspect(sort(tr.ap, by = "lift"))

実行

Rscript で実行してみます。

実行結果(support = 0.05, confidence = 0.7)
> Rscript sample.R data.basket

・・・
    lhs      rhs support confidence lift     count
[1] {B,O} => {W} 0.05    0.8333333  3.333333  5
[2] {B,T} => {C} 0.05    1.0000000  2.439024  5
[3] {T}   => {C} 0.08    0.8000000  1.951220  8
[4] {N}   => {C} 0.10    0.7142857  1.742160 10

補足

なお、当然ながら support や confidence の条件を緩和すると結果数が大きく変わります。

実行結果2(support = 0.03, confidence = 0.5)
> Rscript sample2.R data.basket

・・・
      lhs        rhs support confidence lift     count
[1]   {B,C,W} => {T} 0.04    0.8000000  8.000000  4
[2]   {B,C,F} => {T} 0.03    0.7500000  7.500000  3
[3]   {B,C,R} => {T} 0.03    0.7500000  7.500000  3
[4]   {F,W}   => {T} 0.03    0.6000000  6.000000  3
[5]   {B,R,S} => {Y} 0.03    0.6000000  6.000000  3
・・・
[128] {O,W}   => {B} 0.05    0.5000000  1.190476  5
[129] {D,W}   => {B} 0.03    0.5000000  1.190476  3
[130] {D,S}   => {B} 0.03    0.5000000  1.190476  3

R の MXNet で iris を分類

MXNet で iris を分類」 と同様の処理を R言語で実装してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20171212/

準備

今回は下記サイトの手順に従って MXNet R パッケージの CPU 版を Windows へインストールしました。

MXNet R パッケージの CPU 版を Windows へインストール
cran <- getOption("repos")
cran["dmlc"] <- "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/CRAN/"
options(repos = cran)
install.packages("mxnet")

インストールした mxnet のバージョンが少し古いようですが(現時点の MXNet 最新バージョンは 1.0)、今回はこれを使います。

バージョン確認
> packageVersion('mxnet')
[1]0.10.1’

学習と評価

MXNet には、階層型ニューラルネットワークの学習処理を簡単に実行するための関数 mx.mlp が用意されているので、今回はこれを使います。

引数 備考
hidden_node 隠れ層のノード(ニューロン)数(デフォルトは 1)
out_node 出力ノード数(今回は分類数)
num.round 繰り返し回数(デフォルトは 10)
array.batch.size バッチサイズ(デフォルトは 128)
learning.rate 学習係数
activation 活性化関数(デフォルトは tanh

hidden_node にベクトル(例. c(6, 4))を設定すれば隠れ層が複数階層化されるようです。

iris のデータセットは R に用意されているものを使います。

mx.mlp の入力データには mx.io.DataIter か R の配列 / 行列を使う必要があるようなので(ラベルデータは配列のみ)、data.matrix で行列化しています。

ラベルデータとする iris の種別 iris$Species は因子型ですが、mxnet では因子型を扱えないようなので as.numeric で数値化しています。

ここで as.numeric の結果は 1 ~ 3 の数値になりますが、mxnet で 3種類の分類を行うには 0 ~ 2 でなければならないようなので -1 しています。

一方、predict の結果を max.col(t(<predictの結果>)) で処理すると 1 ~ 3 の数値になるため、評価用のラベルデータは -1 せずに正解率の算出に使っています。

また、array.layout = 'rowmajor' は Warning message 抑制のために設定しています。

iris_hnn.R
library(mxnet)

train_size = 0.7

n = nrow(iris)
# 1 ~ n から無作為に n * train_size 個を抽出
perm = sample(n, size = round(n * train_size))

# 学習用データ
train <- iris[perm, ]
# 評価用データ
test <-iris[-perm, ]

# 学習用入力データ
train.x <- data.matrix(train[1:4])
# 学習用ラベルデータ(0 ~ 2)
train.y <- as.numeric(train$Species) - 1

# 評価用入力データ
test.x <- data.matrix(test[1:4])
# 評価用ラベルデータ(1 ~ 3)
test.y <- as.numeric(test$Species)

mx.set.seed(0)

# 学習
model <- mx.mlp(train.x, train.y, 
                hidden_node = 5, 
                out_node = 3,
                num.round = 100,
                learning.rate = 0.1,
                array.batch.size = 10,
                activation = 'relu',
                array.layout = 'rowmajor',
                eval.metric = mx.metric.accuracy)

# 評価
pred <- predict(model, test.x, array.layout = 'rowmajor')

# 評価用データの分類結果(1 ~ 3)
pred.y <- max.col(t(pred))

# 評価データの正解率を算出
acc <- sum(pred.y == test.y) / length(pred.y)

print(acc)

実行結果は以下の通り。

実行結果
・・・
> model <- mx.mlp(train.x, train.y, 
+                 hidden_node = 5, 
+                 out_node = 3,
+                 num.round = 100,
+                 learning.rate = 0.1,
+                 array.batch.size = 10,
+                 activation = 'relu',
+                 array.layout = 'rowmajor',
+                 eval.metric = mx.metric.accuracy)
Start training with 1 devices
[1] Train-accuracy=0.32
[2] Train-accuracy=0.281818181818182
・・・
[99] Train-accuracy=0.954545454545455
[100] Train-accuracy=0.954545454545455
・・・
> print(acc)
[1] 0.9555556

備考

predict の実行結果は以下のような内容となっています。

> pred

            [,1]        [,2]        [,3]        [,4]
[1,] 0.968931615 0.968931615 0.968931615 0.968931615
[2,] 0.029328469 0.029328469 0.029328469 0.029328469
[3,] 0.001739914 0.001739914 0.001739914 0.001739914
            [,5]        [,6]        [,7]        [,8]
[1,] 0.968931615 0.968931615 0.968931615 0.968931615
[2,] 0.029328469 0.029328469 0.029328469 0.029328469
[3,] 0.001739914 0.001739914 0.001739914 0.001739914
・・・
            [,41]        [,42]        [,43]        [,44]
[1,] 1.762393e-08 7.670556e-06 5.799695e-06 9.349569e-12
[2,] 3.053433e-02 2.679898e-01 1.714197e-01 7.250102e-05
[3,] 9.694657e-01 7.320026e-01 8.285745e-01 9.999275e-01
            [,45]
[1,] 4.420018e-08
[2,] 8.569881e-03
[3,] 9.914301e-01

1 ~ 3 の中で最も数値の高いものが分類結果となりますので、上記t で転置して max.col すると以下のようになります。

> max.col(t(pred))

 [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 3 2 2 2 2 2
[30] 3 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3

Elixir でステートマシンを処理

「Akka でステートマシンを処理」 と同じ処理を gen_statem の Elixir 用ラッパー(以下)を使って実装します。

ソースは http://github.com/fits/try_samples/tree/master/blog/20171121/

ステートマシンの実装(sample1)

まず、以下のステートマシンを実装します。

  • 初期状態は Idle 状態
  • Idle 状態で On イベントが発生すると Active 状態へ遷移
  • Active 状態で Off イベントが発生すると Idle 状態へ遷移
現在の状態 Off On
Idle Active
Active Idle

準備

mix でプロジェクトを作成します。

プロジェクト作成
> mix new sample1

・・・
> cd sample1

mix.exs の deps へ gen_state_machine を追加します。 今回は escript で実行するので、そのための設定も追加しておきます。

mix.exs
defmodule Sample1.Mixfile do
  use Mix.Project

  def project do
    [
      ・・・
      deps: deps(),
      # escript の設定
      escript: [ main_module: Sample1 ]
    ]
  end

  ・・・

  defp deps do
    [
      # GenStateMachine
      {:gen_state_machine, "~> 2.0"}
    ]
  end
end

実装

handle_event(event_type, event_content, state, data) 関数でイベントをハンドリングし、{:next_state, <遷移先の状態>, <新しいデータ>} を返せば新しい状態へ遷移します。

状態遷移しない場合は {:keep_state, <データ>}:keep_state_and_data でも可)を返します。

ここで、cast 関数(返事を待たない一方通行の呼び出し。Akka の tell と同じ)を使った場合のイベントタイプは :cast となります。

lib/sample_state_machine.ex
defmodule SampleStateMachine do
  use GenStateMachine

  # 初期状態
  def init(_args) do
    {:ok, :idle, 0}
  end

  # on イベントの処理(idle から active へ遷移)
  def handle_event(:cast, :on, :idle, data) do
    IO.puts "*** :on, idle -> active"
    {:next_state, :active, data + 1}
  end

  # off イベントの処理(active から idle へ遷移)
  def handle_event(:cast, :off, :active, data) do
    IO.puts "*** :off, active -> idle"
    {:next_state, :idle, data}
  end

  # 上記以外
  def handle_event(event_type, event_content, state, data) do
    IO.puts "*** Unhandled: type=#{event_type}, content=#{event_content}, state=#{state}, data=#{data}"
    {:keep_state, data}
    # 以下でも可
    # {:keep_state_and_data, []}
  end
end

このステートマシンを動作確認するための処理を実装します。 escript で実行できるように main 関数内に定義しています。

lib/sample1.ex
defmodule Sample1 do
  def main(_args) do
    {:ok, pid} = GenStateMachine.start_link(SampleStateMachine, nil)

    GenStateMachine.cast(pid, :on)
    GenStateMachine.cast(pid, :off)

    GenStateMachine.cast(pid, :off)

    GenStateMachine.stop(pid)
  end
end

ビルドと実行

deps.get で gen_state_machine を取得します。

依存パッケージの取得
> mix deps.get

Running dependency resolution...
Dependency resolution completed:
  gen_state_machine 2.0.1
* Getting gen_state_machine (Hex package)
  Checking package (https://repo.hex.pm/tarballs/gen_state_machine-2.0.1.tar)
  Using locally cached package

escript.build で escript 実行用にビルドします。

ビルド
> mix escript.build

==> gen_state_machine
Compiling 3 files (.ex)
Generated gen_state_machine app
==> sample1
Compiling 2 files (.ex)
Generated sample1 app
Generated escript sample1 with MIX_ENV=dev

escript コマンドで実行します。

実行結果
> escript sample1

*** :on, idle -> active
*** :off, active -> idle
*** Unhandled: type=cast, content=off, state=idle, data=1

タイムアウト付きステートマシンの実装(sample2)

次に、タイムアウト時の遷移を追加してみます。

現在の状態 Off On Timeout (2秒)
Idle Active
Active Idle Idle

実装

{:next_state, <遷移先の状態>, <新しいデータ>, <タイムアウト(ミリ秒)>} を返すとタイムアウトを設定できます。

イベントタイプ :timeoutタイムアウトをハンドリングできます。

注意点として、このタイムアウトは状態自体のタイムアウトではなくイベントの受信に対するタイムアウトです。

lib/timeout_state_machine.ex
defmodule TimeoutStateMachine do
  use GenStateMachine

  def init(_args) do
    {:ok, :idle, 0}
  end

  def handle_event(:cast, :on, :idle, data) do
    IO.puts "*** :on, idle -> active"
    # 2秒タイムアウト
    {:next_state, :active, data + 1, 2000}
  end

  def handle_event(:cast, :off, :active, data) do
    IO.puts "*** :off, active -> idle"
    {:next_state, :idle, data}
  end

  # タイムアウト時の処理
  def handle_event(:timeout, event_content, :active, data) do
    IO.puts "*** :timeout content=#{event_content}, active -> idle"
    {:next_state, :idle, data}
  end

  def handle_event(event_type, event_content, state, data) do
    IO.puts "*** Unhandled: type=#{event_type}, content=#{event_content}, state=#{state}, data=#{data}"
    {:keep_state, data}
  end
end

動作確認の処理を実装します。

lib/sample2.ex
defmodule Sample2 do
  def main(_args) do
    {:ok, pid} = GenStateMachine.start_link(TimeoutStateMachine, nil)

    GenStateMachine.cast(pid, :on)
    GenStateMachine.cast(pid, :off)

    GenStateMachine.cast(pid, :off)

    GenStateMachine.cast(pid, :on)

    :timer.sleep(2500)

    GenStateMachine.cast(pid, :on)

    :timer.sleep(1500)

    GenStateMachine.cast(pid, :invalid_message)

    :timer.sleep(1500)

    GenStateMachine.cast(pid, :invalid_message)

    :timer.sleep(2500)

    GenStateMachine.stop(pid)
  end
end

ビルドと実行

依存パッケージの取得とビルド
> mix deps.get
・・・

> mix escript.build
・・・

実行結果は以下の通り、invalid_message のハンドリングでタイムアウトは機能しなくなっています。

実行結果
> escript sample2

*** :on, idle -> active
*** :off, active -> idle
*** Unhandled: type=cast, content=off, state=idle, data=1
*** :on, idle -> active
*** :timeout content=2000, active -> idle
*** :on, idle -> active
*** Unhandled: type=cast, content=invalid_message, state=active, data=3
*** Unhandled: type=cast, content=invalid_message, state=active, data=3

状態タイムアウト付きステートマシンの実装(sample3)

最後に、状態のタイムアウトを実現します。

実装

:next_state を返す際に {:state_timeout, <タイムアウト(ミリ秒)>, <イベント>} を設定したリストを含める事で状態のタイムアウトを実現できます。

状態のタイムアウトはイベントタイプ :state_timeout でハンドリングします。

lib/state_timeout_state_machine.ex
defmodule StateTimeoutStateMachine do
  use GenStateMachine

  def init(_args) do
    {:ok, :idle, 0}
  end

  def handle_event(:cast, :on, :idle, data) do
    IO.puts "*** :on, idle -> active"
    # 状態タイムアウトの設定
    actions = [{:state_timeout, 2000, :off}]
    {:next_state, :active, data + 1, actions}
  end

  def handle_event(:cast, :off, :active, data) do
    IO.puts "*** :off, active -> idle"
    {:next_state, :idle, data}
  end

  # 状態タイムアウトの処理
  def handle_event(:state_timeout, :off, :active, data) do
    IO.puts "*** :state_timeout, active -> idle"
    {:next_state, :idle, data}
  end

  def handle_event(event_type, event_content, state, data) do
    IO.puts "*** Unhandled: type=#{event_type}, content=#{event_content}, state=#{state}, data=#{data}"
    {:keep_state, data}
  end
end

動作確認の処理を実装します。

lib/sample3.ex
defmodule Sample3 do
  def main(_args) do
    {:ok, pid} = GenStateMachine.start_link(StateTimeoutStateMachine, nil)

    GenStateMachine.cast(pid, :on)
    GenStateMachine.cast(pid, :off)

    GenStateMachine.cast(pid, :off)

    GenStateMachine.cast(pid, :on)

    :timer.sleep(2500)

    GenStateMachine.cast(pid, :on)

    :timer.sleep(1500)

    GenStateMachine.cast(pid, :invalid_message)

    :timer.sleep(1500)

    GenStateMachine.cast(pid, :invalid_message)

    :timer.sleep(2500)

    GenStateMachine.stop(pid)
  end
end

ビルドと実行

依存パッケージの取得とビルド
> mix deps.get
・・・

> mix escript.build
・・・

実行結果は以下の通り、invalid_message のハンドリングとは無関係に状態のタイムアウトが機能しています。

実行結果
> escript sample3

*** :on, idle -> active
*** :off, active -> idle
*** Unhandled: type=cast, content=off, state=idle, data=1
*** :on, idle -> active
*** :state_timeout, active -> idle
*** :on, idle -> active
*** Unhandled: type=cast, content=invalid_message, state=active, data=3
*** :state_timeout, active -> idle
*** Unhandled: type=cast, content=invalid_message, state=idle, data=3

MySQL Binary Log connector でバイナリログをイベント処理

MySQL Binary Log connector (mysql-binlog-connector-java) を使うと、Java プログラムで MySQL / MariaDB のバイナリログをイベント処理できます。

そのため、MySQL の CDC(Change Data Capture)として使えるかもしれません。

ソースは http://github.com/fits/try_samples/tree/master/blog/20171030/

Groovy で実装

MySQL へ接続してバイナリログの内容を取得するには BinaryLogClient を使います。

registerEventListener メソッドで EventListener 実装オブジェクトを登録しておくと、バイナリログの内容をデシリアライズした Event オブジェクトを引数として onEvent メソッドを呼び出してくれます。

BinaryLogClient のソース(listenForEventPackets() メソッドなど)を見てみると、バイナリログを順番にデシリアライズして(特定のイベントタイプをスキップするような処理は無さそう)、登録している EventListener の onEvent を順次呼び出しているだけのようなので、場合によっては処理性能に注意が必要かもしれません。

binlog_sample.groovy
@Grab('com.github.shyiko:mysql-binlog-connector-java:0.13.0')
import com.github.shyiko.mysql.binlog.BinaryLogClient

def host = args[0]
def port = args[1] as int
def user = args[2]
def pass = args[3]

def client = new BinaryLogClient(host, port, user, pass)

client.registerEventListener { ev ->
    // バイナリログの内容を処理
    println ev
}

client.connect()

動作確認

動作確認は Docker で行ってみます。

準備

BinaryLogClient で接続するために MySQL 側でレプリケーション用の設定を行います。

まずは、レプリケーションの設定ファイルを用意します。

/home/vagrant/mysql/conf/repl.cnf (レプリケーションの設定)
[mysqld]
log-bin=mysql-bin
server-id=1

次に、レプリケーション用の接続ユーザーを追加するための SQL ファイルも用意しておきます。

/home/vagrant/mysql/init/repl-user.sqlレプリケーション用のユーザー)
GRANT REPLICATION SLAVE, REPLICATION CLIENT ON *.* TO repl@'%' IDENTIFIED BY 'pass';

今回はコンテナ間の接続(DB への接続)に Docker のユーザー定義ネットワークを使います。

そのため、まずはブリッジネットワーク(sample1)を作成しておきます。

Docker ユーザー定義ネットワークの作成
$ docker network create --driver bridge sample1

/home/vagrant/groovy/binlog_sample.groovy ファイルを用意した後、sample1 のネットワークへ参加するように Groovy のコンテナを実行します。

Groovy コンテナ実行
$ docker run --rm -it --net=sample1 -v /home/vagrant/groovy:/work groovy bash

・・・
groovy@・・・:~$ cd /work

a. MySQL 5.7 の場合

それでは、MySQL のコンテナを実行して動作確認を行います。

MySQL の Docker 公式イメージでは、/etc/mysql/conf.d 内の設定ファイルを適用し、/docker-entrypoint-initdb.d 内の SQL ファイルを実行するようになっています。

今回はこれを使って、先ほど用意しておいたレプリケーションの設定ファイルとユーザー作成 SQL を適用するように実行します。

a-1. MySQL コンテナ実行
$ docker run --name mysql1 --net=sample1 -e MYSQL_ROOT_PASSWORD=secret -d -v /home/vagrant/mysql/conf:/etc/mysql/conf.d -v /home/vagrant/mysql/init:/docker-entrypoint-initdb.d mysql

事前に実行しておいた Groovy コンテナ上で binlog_sample.groovy を実行します。 ユーザー名とパスワードはレプリケーション用のものを使用します。

a-2. binlog_sample.groovy 実行(Groovy コンテナ)
groovy@・・・:/work$ groovy binlog_sample.groovy mysql1 3306 repl pass

Oct 22, 2017 4:46:02 PM com.github.shyiko.mysql.binlog.BinaryLogClient connect
INFO: Connected to mysql1:3306 at mysql-bin.000003/154 (sid:65535, cid:3)
Event{header=EventHeaderV4{timestamp=0, eventType=ROTATE, serverId=1, headerLength=19, dataLength=28, nextPosition=0, flags=32}, data=RotateEventData{binlogFilename='mysql-bin.000003', binlogPosition=154}}
・・・

この状態で以下の SQL を実行してみます。

CREATE DATABASE db1;
USE db1;

CREATE TABLE tbl1 (id int NOT NULL PRIMARY KEY, name varchar(10) NOT NULL);

INSERT INTO tbl1 VALUES (1, 'a');
UPDATE tbl1 SET name = 'aa' WHERE id = 1;
DELETE FROM tbl1 WHERE id = 1;

CREATE TABLE tbl2 (id int NOT NULL PRIMARY KEY, name varchar(10) NOT NULL);

START TRANSACTION;

INSERT INTO tbl1 VALUES (1, 'a');
INSERT INTO tbl2 VALUES (2, 'b');

COMMIT;

上記 SQL 実行後の出力結果です。

a-3. binlog_sample.groovy 出力結果
・・・
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=ANONYMOUS_GTID, serverId=1, headerLength=19, dataLength=46, nextPosition=219, flags=0}, data=null}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=QUERY, serverId=1, headerLength=19, dataLength=72, nextPosition=310, flags=8}, data=QueryEventData{threadId=4, executionTime=0, errorCode=0, database='db1', sql='CREATE DATABASE db1'}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=ANONYMOUS_GTID, serverId=1, headerLength=19, dataLength=46, nextPosition=375, flags=0}, data=null}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=QUERY, serverId=1, headerLength=19, dataLength=127, nextPosition=521, flags=0}, data=QueryEventData{threadId=4, executionTime=0, errorCode=0, database='db1', sql='CREATE TABLE tbl1 (id int NOT NULL PRIMARY KEY, name varchar(10) NOT NULL)'}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=ANONYMOUS_GTID, serverId=1, headerLength=19, dataLength=46, nextPosition=586, flags=0}, data=null}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=QUERY, serverId=1, headerLength=19, dataLength=52, nextPosition=657, flags=8}, data=QueryEventData{threadId=4, executionTime=0, errorCode=0, database='db1', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=TABLE_MAP, serverId=1, headerLength=19, dataLength=30, nextPosition=706, flags=0}, data=TableMapEventData{tableId=219, database='db1', table='tbl1', columnTypes=3, 15, columnMetadata=0, 10, columnNullability={}}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=EXT_WRITE_ROWS, serverId=1, headerLength=19, dataLength=23, nextPosition=748, flags=0}, data=WriteRowsEventData{tableId=219, includedColumns={0, 1}, rows=[
    [1, a]
]}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=779, flags=0}, data=XidEventData{xid=14}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=ANONYMOUS_GTID, serverId=1, headerLength=19, dataLength=46, nextPosition=844, flags=0}, data=null}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=QUERY, serverId=1, headerLength=19, dataLength=52, nextPosition=915, flags=8}, data=QueryEventData{threadId=4, executionTime=0, errorCode=0, database='db1', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=TABLE_MAP, serverId=1, headerLength=19, dataLength=30, nextPosition=964, flags=0}, data=TableMapEventData{tableId=219, database='db1', table='tbl1', columnTypes=3, 15, columnMetadata=0, 10, columnNullability={}}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=EXT_UPDATE_ROWS, serverId=1, headerLength=19, dataLength=32, nextPosition=1015, flags=0}, data=UpdateRowsEventData{tableId=219, includedColumnsBeforeUpdate={0, 1}, includedColumns={0, 1}, rows=[
    {before=[1, a], after=[1, aa]}
]}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=1046, flags=0}, data=XidEventData{xid=15}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=ANONYMOUS_GTID, serverId=1, headerLength=19, dataLength=46, nextPosition=1111, flags=0}, data=null}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=QUERY, serverId=1, headerLength=19, dataLength=52, nextPosition=1182, flags=8}, data=QueryEventData{threadId=4, executionTime=0, errorCode=0, database='db1', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=TABLE_MAP, serverId=1, headerLength=19, dataLength=30, nextPosition=1231, flags=0}, data=TableMapEventData{tableId=219, database='db1', table='tbl1', columnTypes=3, 15, columnMetadata=0, 10, columnNullability={}}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=EXT_DELETE_ROWS, serverId=1, headerLength=19, dataLength=24, nextPosition=1274, flags=0}, data=DeleteRowsEventData{tableId=219, includedColumns={0, 1}, rows=[
    [1, aa]
]}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=1305, flags=0}, data=XidEventData{xid=16}}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=ANONYMOUS_GTID, serverId=1, headerLength=19, dataLength=46, nextPosition=1370, flags=0}, data=null}
Event{header=EventHeaderV4{timestamp=1508690815000, eventType=QUERY, serverId=1, headerLength=19, dataLength=127, nextPosition=1516, flags=0}, data=QueryEventData{threadId=4, executionTime=1, errorCode=0, database='db1', sql='CREATE TABLE tbl2 (id int NOT NULL PRIMARY KEY, name varchar(10) NOT NULL)'}}
Event{header=EventHeaderV4{timestamp=1508690819000, eventType=ANONYMOUS_GTID, serverId=1, headerLength=19, dataLength=46, nextPosition=1581, flags=0}, data=null}
Event{header=EventHeaderV4{timestamp=1508690816000, eventType=QUERY, serverId=1, headerLength=19, dataLength=52, nextPosition=1652, flags=8}, data=QueryEventData{threadId=4, executionTime=0, errorCode=0, database='db1', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508690816000, eventType=TABLE_MAP, serverId=1, headerLength=19, dataLength=30, nextPosition=1701, flags=0}, data=TableMapEventData{tableId=219, database='db1', table='tbl1', columnTypes=3, 15, columnMetadata=0, 10, columnNullability={}}}
Event{header=EventHeaderV4{timestamp=1508690816000, eventType=EXT_WRITE_ROWS, serverId=1, headerLength=19, dataLength=23, nextPosition=1743, flags=0}, data=WriteRowsEventData{tableId=219, includedColumns={0, 1}, rows=[
    [1, a]
]}}
Event{header=EventHeaderV4{timestamp=1508690816000, eventType=TABLE_MAP, serverId=1, headerLength=19, dataLength=30, nextPosition=1792, flags=0}, data=TableMapEventData{tableId=220, database='db1', table='tbl2', columnTypes=3, 15, columnMetadata=0, 10, columnNullability={}}}
Event{header=EventHeaderV4{timestamp=1508690816000, eventType=EXT_WRITE_ROWS, serverId=1, headerLength=19, dataLength=23, nextPosition=1834, flags=0}, data=WriteRowsEventData{tableId=220, includedColumns={0, 1}, rows=[
    [2, b]
]}}
Event{header=EventHeaderV4{timestamp=1508690819000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=1865, flags=0}, data=XidEventData{xid=19}}

eventType を簡単にまとめると以下のようになりました。

SQL eventType
CREATE QUERY
INSERT EXT_WRITE_ROWS
UPDATE EXT_UPDATE_ROWS
DELETE EXT_DELETE_ROWS

注意点として、EXT_XXX_ROWS の Event 内容にテーブル名は含まれておらず tableId で判断する必要がありそうです。 tableId とテーブル名のマッピングは直前の TABLE_MAP で実施されています。

また、バイナリログのフォーマット(以下)は RBR(行ベースレプリケーション) となっていました。

mysql> show global variables like 'binlog_format';
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
| binlog_format | ROW   |
+---------------+-------+

b. MariaDB 10.2 の場合

ついでに、MariaDB でも試してみます。

設定は MySQL と同じものが使えるので、ここでは Docker イメージ名を mariadb に変えて実行するだけです。(以下ではコンテナ名も変えています)

b-1. MariaDB コンテナ実行
$ docker run --name mariadb1 --net=sample1 -e MYSQL_ROOT_PASSWORD=secret -d -v /home/vagrant/mysql/conf:/etc/mysql/conf.d -v /home/vagrant/mysql/init:/docker-entrypoint-initdb.d mariadb

接続先を mariadb1 (MariaDB) へ変えてスクリプトを実行します。

b-2. binlog_sample.groovy 実行(Groovy コンテナ)
groovy@・・・:/work$ groovy binlog_sample.groovy mariadb1 3306 repl pass

・・・

MySQL と同じ SQL を実行した後の出力結果です。

b-3. binlog_sample.groovy 出力結果
・・・
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=23, nextPosition=384, flags=8}, data=QueryEventData{threadId=0, executionTime=0, errorCode=0, database='', sql='# Dum'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=66, nextPosition=469, flags=8}, data=QueryEventData{threadId=10, executionTime=0, errorCode=0, database='db1', sql='CREATE DATABASE db1'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=23, nextPosition=511, flags=8}, data=QueryEventData{threadId=0, executionTime=0, errorCode=0, database='', sql='# Dum'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=121, nextPosition=651, flags=0}, data=QueryEventData{threadId=10, executionTime=0, errorCode=0, database='db1', sql='CREATE TABLE tbl1 (id int NOT NULL PRIMARY KEY, name varchar(10) NOT NULL)'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=23, nextPosition=693, flags=8}, data=QueryEventData{threadId=0, executionTime=0, errorCode=0, database='', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=79, nextPosition=791, flags=0}, data=QueryEventData{threadId=10, executionTime=0, errorCode=0, database='db1', sql='INSERT INTO tbl1 VALUES (1, 'a')'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=822, flags=0}, data=XidEventData{xid=12}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=23, nextPosition=864, flags=8}, data=QueryEventData{threadId=0, executionTime=0, errorCode=0, database='', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=87, nextPosition=970, flags=0}, data=QueryEventData{threadId=10, executionTime=0, errorCode=0, database='db1', sql='UPDATE tbl1 SET name = 'aa' WHERE id = 1'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=1001, flags=0}, data=XidEventData{xid=13}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=23, nextPosition=1043, flags=8}, data=QueryEventData{threadId=0, executionTime=0, errorCode=0, database='', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=76, nextPosition=1138, flags=0}, data=QueryEventData{threadId=10, executionTime=0, errorCode=0, database='db1', sql='DELETE FROM tbl1 WHERE id = 1'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=1169, flags=0}, data=XidEventData{xid=14}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=23, nextPosition=1211, flags=8}, data=QueryEventData{threadId=0, executionTime=0, errorCode=0, database='', sql='# Dum'}}
Event{header=EventHeaderV4{timestamp=1508691100000, eventType=QUERY, serverId=1, headerLength=19, dataLength=121, nextPosition=1351, flags=0}, data=QueryEventData{threadId=10, executionTime=1, errorCode=0, database='db1', sql='CREATE TABLE tbl2 (id int NOT NULL PRIMARY KEY, name varchar(10) NOT NULL)'}}
Event{header=EventHeaderV4{timestamp=1508691101000, eventType=QUERY, serverId=1, headerLength=19, dataLength=23, nextPosition=1393, flags=8}, data=QueryEventData{threadId=0, executionTime=0, errorCode=0, database='', sql='BEGIN'}}
Event{header=EventHeaderV4{timestamp=1508691101000, eventType=QUERY, serverId=1, headerLength=19, dataLength=79, nextPosition=1491, flags=0}, data=QueryEventData{threadId=10, executionTime=0, errorCode=0, database='db1', sql='INSERT INTO tbl1 VALUES (1, 'a')'}}
Event{header=EventHeaderV4{timestamp=1508691101000, eventType=QUERY, serverId=1, headerLength=19, dataLength=79, nextPosition=1589, flags=0}, data=QueryEventData{threadId=10, executionTime=0, errorCode=0, database='db1', sql='INSERT INTO tbl2 VALUES (2, 'b')'}}
Event{header=EventHeaderV4{timestamp=1508691101000, eventType=XID, serverId=1, headerLength=19, dataLength=12, nextPosition=1620, flags=0}, data=XidEventData{xid=17}}

eventType を簡単にまとめると以下のようになりました。

SQL eventType
CREATE QUERY
INSERT QUERY
UPDATE QUERY
DELETE QUERY

バイナリログのフォーマットは MBRミックスベースレプリケーション)となっていました。

MariaDB [(none)]> show global variables like 'binlog_format';
+---------------+-------+
| Variable_name | Value |
+---------------+-------+
| binlog_format | MIXED |
+---------------+-------+