Apache Spark でロジスティック回帰

以前 ※ に R や Julia で試したロジスティック回帰を Apache Spark の MLlib (Machine Learning Library) を使って実施してみました。

サンプルソースhttp://github.com/fits/try_samples/tree/master/blog/20150427/

※「 R でロジスティック回帰 - glm, MCMCpack 」、「 Julia でロジスティック回帰-glm

はじめに

R の時と同じデータを使いますが、ヘッダー行を削除しています。(「R でロジスティック回帰 - glm, MCMCpack」 参照)

データ data4a.csv
8,1,9.76,C
8,6,10.48,C
8,5,10.83,C
・・・

データ内容は以下の通り。個体 i それぞれにおいて 「 { N_i } 個の観察種子のうち生きていて発芽能力があるものは { y_i } 個」 となっています。

項目 内容
N 観察種子数
y 生存種子数
x 植物の体サイズ
f 施肥処理 (C: 肥料なし, T: 肥料あり)

体サイズ x と肥料による施肥処理 f が種子の生存する確率(ある個体 i から得られた種子が生存している確率)にどのように影響しているかをロジスティック回帰で解析します。

MLlib によるロジスティック回帰

今回は org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS を使用します。

LogisticRegressionWithLBFGS について

LogisticRegressionWithLBFGS で以前と同様のロジスティック回帰を実施するには以下が必要です。

  • setIntercept で true を設定

この値が false (デフォルト値) の場合、結果の intercept 値が 0 になります。

なお、今回のように二項分布を使う場合は numClasses の値を変更する必要はありませんが (デフォルト値が 2 のため)、応答変数が 3状態以上の多項分布を使う場合は setNumClasses で状態数に応じた値を設定します。

LabeledPoint について

LogisticRegressionWithLBFGS へ与えるデータは LabeledPoint で用意します。

R や Julia では 応答変数 ~ 説明変数1 + 説明変数2 + ・・・ のように応答変数と説明変数を指定しましたが、LabeledPoint では下記のようにメンバー変数で表現します。

メンバー変数 応答変数・説明変数
label 応答変数
features 説明変数

値は Double とする必要がありますので、f 項目のような文字列値は数値化します。

更に、二項分布を使う場合 (numClasses = 2) は応答変数の値が 0 か 1 でなければなりません。

LabeledPoint への変換例

例えば、以下のようなデータを応答変数 y 項目、説明変数 x と f 項目の LabeledPoint へ変換する場合

変換前のデータ (N = 8, y = 6)
8,6,10.48,C

次のようになります。

変換後のデータイメージ
LabeledPoint(label: 1.0, features: Vector(10.48, 0.0))
LabeledPoint(label: 1.0, features: Vector(10.48, 0.0))
LabeledPoint(label: 1.0, features: Vector(10.48, 0.0))
LabeledPoint(label: 1.0, features: Vector(10.48, 0.0))
LabeledPoint(label: 1.0, features: Vector(10.48, 0.0))
LabeledPoint(label: 1.0, features: Vector(10.48, 0.0))
LabeledPoint(label: 0.0, features: Vector(10.48, 0.0))
LabeledPoint(label: 0.0, features: Vector(10.48, 0.0))

8個(N)の中で 6個(y)生存していたデータのため、 label (応答変数) の値が 1.0 (生存) のデータ 6個と 0.0 のデータ 2個へ変換します。

ちなみに、f 項目の値が C の場合は 0.0、T の場合は 1.0 としています。

実装

実装してみると以下のようになります。

LogisticRegression.scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

object LogisticRegression extends App {
    // f項目の値を数値へ変換
    val factor = (s: String) => s match {
        case "C" => 0
        case _ => 1
    }

    val sc = new SparkContext("local", "LogisticRegression")

    // データの準備 (100行のデータ -> 800個の LabeledPoint)
    val rdd = sc.textFile(args(0)).map(_.split(",")).flatMap { d =>
        val n = d(0).toInt
        val x = d(1).toInt
        // 説明変数の値
        val v = Vectors.dense(d(2).toDouble, factor(d(3)))

        // 応答変数が 1 のデータ x 個と 0 のデータ n - x 個を作成
        List.fill(x)( LabeledPoint(1, v) ) ++ 
            List.fill(n -x)( LabeledPoint(0, v) )
    }

    // ロジスティック回帰の実行
    val res = new LogisticRegressionWithLBFGS()
//      .setNumClasses(2) //省略可
        .setIntercept(true)
        .run(rdd)

    println(res)
}

ビルド

以下のような Gradle ビルド定義ファイルを使って実行します。

build.gradle
apply plugin: 'scala'
apply plugin: 'application'

mainClassName = 'LogisticRegression'

repositories {
    jcenter()
}

dependencies {
    compile 'org.scala-lang:scala-library:2.11.6'

    compile('org.apache.spark:spark-mllib_2.11:1.3.1') {
        // ログ出力の抑制
        exclude module: 'slf4j-log4j12'
    }

    // ログ出力の抑制
    runtime 'org.slf4j:slf4j-nop:1.7.12'
}

run {
    if (project.hasProperty('args')) {
        args project.args.split(' ')
    }
}

不要な WARN ログ出力を抑制するため以下のファイルも用意しました。

src/main/resources/log4j.properties
log4j.rootLogger=off

実行

実行結果は以下の通りです。

実行結果
> gradle run -Pargs=data4a.csv

:clean
:compileJava UP-TO-DATE
:compileScala
:processResources
:classes
:run
(weights=[1.952347703282676,2.021401680901667], intercept=-19.535421113192506)

BUILD SUCCESSFUL

以前に実施した R の結果 (Estimate の値) とほとんど同じ値になっています。

R の glm 関数による結果
Coefficients:
            Estimate Std. Error z value Pr(>|z|)    
(Intercept) -19.5361     1.4138  -13.82   <2e-16 ***
x             1.9524     0.1389   14.06   <2e-16 ***
fT            2.0215     0.2313    8.74   <2e-16 ***

Spark SQL で CSV ファイルを処理2 - GeoLite2

前回の 「Spark SQL で CSV ファイルを処理 - GeoLite Legacy」 に続き、今回は Spark SQL を使って GeoLite2 City CSV ファイルを処理してみます。

今回のソースは http://github.com/fits/try_samples/tree/master/blog/20141112/

はじめに

GeoLite2 City の CSV は下記のような 2種類のファイルで構成しています。

  • GeoLite2-City-Blocks.csv (IP と都市情報とのマッピング
  • GeoLite2-City-Locations.csv (国・都市情報)

GeoLite2-City-Blocks.csv で IP アドレスから geoname_id を割り出し、GeoLite2-City-Locations.csv で geoname_id から国・都市を特定します。

ファイルの内容は下記のようになっており、IP は IPv6 の形式で記載されています。

GeoLite2-City-Blocks.csv の例
network_start_ip,network_prefix_length,geoname_id,registered_country_geoname_id,represented_country_geoname_id,postal_code,latitude,longitude,is_anonymous_proxy,is_satellite_provider
・・・
::ffff:1.0.64.0,114,1862415,1861060,,,・・・
・・・
2602:30a:2c1d::,48,5368361,,,・・・
・・・
GeoLite2-City-Locations.csv の例
geoname_id,continent_code,continent_name,country_iso_code,country_name,subdivision_iso_code,subdivision_name,city_name,metro_code,time_zone
1862415,AS,Asia,JP,Japan,34,Hiroshima,・・・
・・・

Spark SQL を使って IP アドレスから都市判定

GeoLite Legacy の Country CSV を処理した前回との違いは、下記 2点です。

  • (1) GeoLite2-City-Blocks.csv と GeoLite2-City-Locations.csv の 2つの CSV を geoname_id で join する
  • (2) network_start_ip と network_prefix_length を使って IP アドレスの数値の範囲を算出する

(1) は前回と同様に CSV を処理して SQL で join するだけです。 (2) は下記のようにして求める事ができます。

  • (a) IP アドレスの開始値は network_start_ip を数値化
  • (b) IP アドレスの終了値は (a) の値の下位 128 - network_prefix_length ビットを全て 1 とした値

今回は IPv4 のみを対象とするため、GeoLite2-City-Blocks.csv::ffff: で始まる行だけを使って (::ffff: 以降がそのまま IPv4 に該当)、上記 (a) と (b) の処理を実装してみました。

注意点として、GeoLite2-City-Locations.csv には subdivision_iso_code 以降が全て空欄のデータも含まれていました。 (例えば 2077456,OC,Oceania,AU,Australia,,,,,split(",") すると Array(2077456, OC, Oceania, AU, Australia) となってしまいます)

GetCity.scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

import java.net.InetAddress

// GeoLite2-City-Blocks.csv 用のスキーマ定義
case class IpMapping(startIpNum: Long, endIpNum: Long, geonameId: String)
// GeoLite2-City-Locations.csv 用のスキーマ定義
case class City(geonameId: String, country: String, city: String)

object GetCity extends App {
    if (args.length < 1) {
        println("<ip address>")
        System.exit(0)
    }

    // IPv4 の数値変換
    val toIpNum = (ip: String) => Integer.toUnsignedLong(InetAddress.getByName(ip).hashCode())

    val locationFile = "GeoLite2-City-Locations.csv"
    val blockFile = "GeoLite2-City-Blocks.csv"

    val sc = new SparkContext("local", "GetCity")

    val sqlContext = new SQLContext(sc)

    import sqlContext.createSchemaRDD

    val locations = sc.textFile(locationFile).map(_.split(",")).map { r =>
        // City 情報の無いデータ(subdivision_iso_code 以降が空欄)への対処
        val city = if (r.length > 7) r(7) else ""
        City(r(0), r(4), city)
    }

    locations.registerTempTable("locations")

    // IPv4 のみ (::ffff: で始まるもの) を対象
    val blocks = sc.textFile(blockFile).filter(_.startsWith("::ffff:")).map(_.split(",")).map { r =>
        val mask = -1 << (128 - r(1).toInt)
        // (a)
        val startIpNum = toIpNum(r(0).replaceAll("::ffff:", ""))
        // (b)
        val endIpNum = startIpNum | ~mask

        IpMapping(startIpNum, endIpNum, r(2))
    }

    blocks.registerTempTable("blocks")

    val ipNum = toIpNum(args(0))

    val rows = sqlContext.sql(s"""
        select
            city,
            country
        from
            locations lo
            join blocks bl on
                bl.geonameId = lo.geonameId
        where
            startIpNum <= ${ipNum} and
            endIpNum >= ${ipNum}
    """)

    rows.foreach( r => println(s"${r(0)}, ${r(1)}") )
}

上記では、IP の終了値 (b) を算出するために、上位ビットを 1、下位ビットを 0 にした mask を作成し、これをビット反転して開始値 (a) と論理和をとっています。

例えば、network_start_ip が ::ffff:1.0.64.0 で network_prefix_length が 114 のデータの場合、(a) の値は 1.0.64.0 を数値化して 16793600、mask 変数の値は 2進数で ・・・111100000000000000、(b) の値は mask 変数の値をビット反転した 011111111111111 と (a) の値との論理和16809983 となり、16793600 ~ 16809983 の範囲内にある IP アドレスが該当する事になります。

実行 (Gradle 利用)

  • Gradle 2.1

前回と同様に Gradle で実行します。
slf4j-nop を使って Spark の標準的なログ出力を抑制している点も同じです。

build.gradle
apply plugin: 'application'
apply plugin: 'scala'

repositories {
    mavenCentral()
}

dependencies {
    compile 'org.scala-lang:scala-library:2.10.4'
    compile('org.apache.spark:spark-sql_2.10:1.1.0') {
        exclude module: 'slf4j-log4j12'
    }
    runtime 'org.slf4j:slf4j-nop:1.7.7'
}

mainClassName = 'GetCity'

run {
    if (project.hasProperty('args')) {
        args project.args.split(' ')
    }
}
実行結果1
> gradle run -q -Pargs=1.21.127.254

Tokyo, Japan
実行結果2
> gradle run -q -Pargs=223.255.254.1

, Singapore

Spark SQL で CSV ファイルを処理 - GeoLite Legacy

以前、H2 を使って CSV ファイルを SQL で処理しましたが、今回は Spark SQL を使ってみました。

IPアドレスから地域を特定する2 - GeoLite Legacy Country CSV」 で使った GeoLite Legacy Country CSV を使って同様の処理を Spark SQL で実装します。

今回のソースは http://github.com/fits/try_samples/tree/master/blog/20141103-2/

Spark SQL を使って IP アドレスから国判定

Spark SQL で扱うテーブルのスキーマを定義する方法はいくつか用意されているようですが、今回はケースクラスをスキーマとして登録する方法で実装しました。

処理の手順は下記のようになります。

  • (1) スキーマ用のクラス定義
  • (2) CSV ファイルを処理して RDD 作成
  • (3) テーブル登録
  • (4) SQL の実行

(2) の処理で (1) のケースクラスを格納した RDD を作成し、(3) の処理で (2) で処理したオブジェクトをテーブルとして登録します。

(2) の処理までは通常の Spark の API を使った処理ですが、import sqlContext.createSchemaRDD によって (3) で registerTempTable メソッドを呼び出す際に RDD から Spark SQLSchemaRDD へ暗黙変換が実施されます。

registerTempTable の引数としてテーブル名を渡す事で、SQL 内でこのテーブル名を使用できるようになります。

そのあとは SQL を実行して結果を出力するだけです。

foreach の要素となる org.apache.spark.sql.Row の実体は org.apache.spark.sql.catalyst.expressions.Row トレイトで、このトレイトが Seq トレイトを extends しているため head などの Seq の API も使えます。

GetCountry.scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

import java.net.InetAddress

// (1) スキーマ用のクラス定義
case class IpCountry(startIpNum: Long, endIpNum: Long, countryName: String)

object GetCountry extends App {
    if (args.length < 1) {
        println("<ip address>")
        System.exit(0)
    }

    val countryFile = "GeoIPCountryWhois.csv"

    val sc = new SparkContext("local", "GetCountry")

    val sqlContext = new SQLContext(sc)

    // RDD を SchemaRDD へ暗黙変換するための定義
    import sqlContext.createSchemaRDD

    // (2) CSV ファイルを処理して RDD 作成
    val countries = sc.textFile(countryFile).map(_.replaceAll("\"", "").split(",")).map { d =>
        IpCountry(d(2).toLong, d(3).toLong, d(5))
    }
    // (3) テーブル登録
    countries.registerTempTable("countries")

    val ipNum = Integer.toUnsignedLong( InetAddress.getByName(args(0)).hashCode )
    // (4) SQL 実行
    val rows = sqlContext.sql(s"""
        select
            countryName
        from
            countries
        where
            startIpNum <= ${ipNum} and
            endIpNum >= ${ipNum}
    """)

    rows.foreach( r => println(r.head) )
}

実行 (Gradle 利用)

  • Gradle 2.1

今回は Gradle で実行するため、下記のようなビルド定義ファイルを用意しました。

現時点では、Maven のセントラルリポジトリScala 2.11 用の Spark SQL の JAR ファイルは用意されていないようなので、Scala 2.10.4 を使います。

今回の用途では Spark の標準的なログ出力が邪魔だったので slf4j-log4j12 の代わりに slf4j-nop を使うようにしてログ出力を抑制しました。

build.gradle
apply plugin: 'application'
apply plugin: 'scala'

repositories {
    mavenCentral()
}

dependencies {
    compile 'org.scala-lang:scala-library:2.10.4'
    compile('org.apache.spark:spark-sql_2.10:1.1.0') {
        // Spark のログ出力を抑制
        exclude module: 'slf4j-log4j12'
    }
    runtime 'org.slf4j:slf4j-nop:1.7.7'
}

mainClassName = 'GetCountry'

run {
    if (project.hasProperty('args')) {
        // コマンドライン引数の設定
        args project.args.split(' ')
    }
}

更に、Gradle のログ出力 (タスクの実行経過) も抑制したいので、-q オプションを使って実行しました。

実行結果1
> gradle run -q -Pargs=1.21.127.254

Japan
実行結果2
> gradle run -q -Pargs=223.255.254.1

Singapore