Spark を使って単純なレコメンドを実施

分散処理フレームワークSpark を使って、id:fits:20111123 で実施したような GitHub データの簡単なレコメンドを実施してみます。

Spark はインメモリーに分散データをキャッシュできる等の特徴があり、個人的に Scala のコレクション API 風に MapReduce 処理を実装できる点が気に入っています。

サンプルソースは http://github.com/fits/try_samples/tree/master/blog/20111215/

なお、入力データは id:fits:20111123 で使った CSV ファイル(以下のフォーマット)をそのまま使う事にします。

入力データ例(grails_watcher_watched.csv
・・・
261649,fits,158886,grails,https://github.com/grails/grails
261649,fits,108110,mongo,https://github.com/mongodb/mongo
261649,fits,2404027,storm,https://github.com/nathanmarz/storm
・・・

事前準備

今のところ Spark を使うにはソースからビルドするしかなさそうなので、https://github.com/mesos/spark からソースを取得し、sbt を使ってビルドします。

sbt assembly を実行すると必要なライブラリをパッケージングした単一の JAR ファイルが生成されますので、これをクラスパスに追加して使います。(core/target/spark-core-assembly-0.4-SNAPSHOT.jar)

なお、sbt の JAR ファイルと実行用のシェルスクリプトが Spark ソースの sbt ディレクトリに用意されているのですが、Windows 環境用のスクリプトは用意されていないので自前で用意する必要があります。

ビルド例
> sbt assembly

概要

レコメンド処理の実用性はあまり考えず、おおまかに以下のような手順でレコメンド処理を実施してみました。

  1. ターゲットユーザーと同じリポジトリを一定数以上 watch しているユーザーを抽出
  2. 抽出したユーザーが watch していてターゲットユーザーが watch していないリポジトリを抽出
  3. 抽出したリポジトリを watch しているユーザー数で集計後、ソートして上位を抽出

実装

実装内容は以下の通りです。

Spark ではデータセットを抽象化した RDD クラスとそのサブクラスに対して処理を行っていきます。

  1. リポジトリ毎に集計
    • map で CSV の内容をからリポジトリ名とユーザー名・評価値 ※ の組み合わせデータを作成 *1
    • groupByKey でリポジトリ毎にユーザーをグルーピング
      • (リポジトリ名, ArrayBuffer( (ユーザー名, 評価値), ・・・) )
    • mapValues でターゲットユーザーが watch しているリポジトリに対してはユーザー毎に評価値の差をとり *2 、watch していないリポジトリは None を設定
      • (リポジトリ名, ArrayBuffer( (ユーザー名, 評価値の差), ・・・) )
      • (リポジトリ名, ArrayBuffer( (ユーザー名, None), ・・・) )
  2. ユーザー毎に集計
    • flatMap と map でユーザーをキーにしたデータに変更
    • groupByKey でユーザー毎にリポジトリをグルーピング
      • (ユーザー名, ArrayBuffer( (リポジトリ名, 評価値の差 or None), ・・・) )
    • mapValues で評価値の差を使ったスコア算出 ※
      • (ユーザー名, (スコア, ArrayBuffer( (リポジトリ名, 評価値の差 or None), ・・・) ) )
  3. ターゲットユーザー以外でスコアが 7 より大きいユーザーを抽出
  4. 抽出した中でターゲットユーザーが watch していないリポジトリの watch ユーザー数をカウント
  5. watch ユーザー数の多い順にソートして上位 5件を抽出して出力 *3

※ 今回のケースでは、評価値なんて必要無くて差を算出する意味は無いのですが、評価値付きのデータにも対応し易いようにこのような実装にしてみました

simple_github_recommend.scala
import scala.math._
import spark._
import SparkContext._

object SimpleGitHubRecommend {
    def main(args: Array[String]) {
        if (args.length < 3) {
            println("parameters: <host> <data file> <target user>")
            return
        }

        //第一引数でローカル実行か Mesos によるクラスタ実行かを指定
        val spark = new SparkContext(args(0), "SimpleGitHubRecommend")

        val file = spark.textFile(args(1))
        val targetUser = args(2)

        // 1. リポジトリ毎に集計
        val itemsRes = file.map { l =>
            val fields = l.split(",")
            //評価値は 1.0 を設定
            (fields(3), (fields(1), 1.0))
        }.groupByKey().mapValues { users =>
            users.find(_._1 == targetUser) match {
                //ターゲットユーザーを含む場合は評価値の差を設定
                case Some((_, targetPoint)) => users.map { case (user, point) =>
                    (user, abs(point - targetPoint))
                }
                //ターゲットユーザーを含まない場合は None を設定
                case None => users.map { case (user, point) => (user, None) }
            }
        }

        // 2. ユーザー単位の集計
        val usersRes = itemsRes.flatMap { case (item, users) =>
            users.map { case (user, point) => (user, (item, point)) }
        }.groupByKey().mapValues { items =>
            //評価値の差を使ってスコア算出
            val score = items.foldLeft(0.0) { (subTotal, itemTuple) =>
                itemTuple._2 match {
                    case p: Double => subTotal + 1.0 / (1.0 + p)
                    case None => subTotal
                }
            }
            (score, items)
        }

        // 3. ターゲットユーザー以外でスコアが 7 より大きいユーザーを抽出
        val pickupRes = usersRes.filter { case (user, (score, _)) =>
            score > 7 && user != targetUser
        }

        // 4. 抽出されたユーザーの中でターゲットユーザーが watch していない
        //リポジトリをカウント
        val res = pickupRes.flatMap { case (user, (score, items)) =>
            items.filter(_._2 == None).map { case (item, _) =>
                (item, 1)
            }
        }.reduceByKey(_ + _)

        //sortBy のソート順を設定
        implicit val order = Ordering.Int.reverse

        // 5. watch しているユーザー数の多い上位 5件のリポジトリを抽出して出力
        res.collect.sortBy(_._2).take(5).foreach { case (item, num) =>
            printf("%s : %d\n", item, num)
        }
    }
}

SparkContext コンストラクタの第一引数でローカル実行(local または local[N])、もしくは Mesos を使ったクラスタ実行を指定できます。

local[N] はローカル実行時のワーカースレッドを複数(N 個)指定する場合の指定方法です。

実行

今回はローカル実行のみ試してみました。

単一のワーカースレッドでローカル実行
> scala -cp spark-core-assembly-0.4-SNAPSHOT.jar simple_github_recommend.scala local grails_watcher_watched.csv fits
・・・
11/12/15 23:22:01 INFO spark.LocalScheduler: Finished task 0
11/12/15 23:22:01 INFO spark.LocalScheduler: Completed ResultTask(0, 0)
11/12/15 23:22:01 INFO spark.SparkContext: Job finished in 2.433545777 s
dotfiles : 20
jquery : 13
jekyll : 11
sproutcore : 10
bootstrap : 9
4つのワーカースレッドでローカル実行
> scala -cp spark-core-assembly-0.4-SNAPSHOT.jar simple_github_recommend.scala local[4] grails_watcher_watched.csv fits
・・・
11/12/15 23:22:20 INFO spark.LocalScheduler: Finished task 3
・・・
11/12/15 23:22:20 INFO spark.LocalScheduler: Completed ResultTask(0, 3)
・・・
11/12/15 23:22:20 INFO spark.LocalScheduler: Finished task 2
11/12/15 23:22:20 INFO spark.LocalScheduler: Completed ResultTask(0, 2)
11/12/15 23:22:20 INFO spark.LocalScheduler: Finished task 0
11/12/15 23:22:20 INFO spark.LocalScheduler: Completed ResultTask(0, 0)
11/12/15 23:22:20 INFO spark.LocalScheduler: Finished task 1
11/12/15 23:22:20 INFO spark.LocalScheduler: Completed ResultTask(0, 1)
11/12/15 23:22:20 INFO spark.SparkContext: Job finished in 2.866628068 s
dotfiles : 20
jquery : 13
jekyll : 11
sproutcore : 10
couchdb : 9

*1:今回の CSV データに評価値が無いので 1.0 を設定しています

*2:ユーザーの評価値からターゲットユーザーの評価値を引いた値の絶対値、今回は常に 0 になります

*3:RDD にソート用のメソッドが無かったため、collect を使って Array 化しました