Scalaz で継続モナド

以前(id:fits:20121104)、Haskell で実装した継続モナドのサンプルを Scalaz で実装してみました。

なお、今のところ Scalaz に継続モナドは用意されていないようで、id:fits:20121111 のような方法で自作する必要がありました。

ただし、実際のところ Scala で継続モナドが必要となるケースは基本的に無さそうな気がします。


サンプルソースは http://github.com/fits/try_samples/tree/master/blog/20121125/


今回使用した sbt 用のビルドファイルは以下です。

build.sbt
scalaVersion := "2.10.0-RC2"

libraryDependencies += "org.scalaz" % "scalaz-core" % "7.0.0-M4" cross CrossVersion.full

継続モナドの実装

はじめに、継続モナドを実装します。
Haskell の実装を参考に Cont モナドと callCC を定義してみました。

Cont.scala (継続モナドの定義)
package fits.sample

import scalaz._
import Scalaz._

// モナドとして扱う型の定義
case class Cont[R, A](runCont: (A => R) => R)
// 継続モナドの注入関数・連鎖関数を実装
trait ContMonad[R] extends Monad[({type r[a] = Cont[R, a]})#r] {
    // 注入関数の実装
    def point[A](a: => A) = Cont { k => k(a) }
    // 連鎖関数の実装
    def bind[A, B](fa: Cont[R, A])(f: (A) => Cont[R, B]) = {
        Cont { k =>
            fa.runCont { a =>
                f(a).runCont(k)
            }
        }
    }
}

trait ContFunctions {
    // callCC の実装
    def callCC[R, A, B](f: (A => Cont[R, B]) => Cont[R, A]): Cont[R, A] = {
        Cont { k =>
            f { a =>
                Cont { _ => k(a) }
            }.runCont(k)
        }
    }
}

trait ContInstances {
    // 継続モナドのインスタンス定義
    implicit def contInstance[R] = new ContMonad[R] {
    }
}

case object Cont extends ContFunctions with ContInstances

バインド >>= の処理

id:fits:20121104 と同様にバインド関数 >>= を使った簡単なサンプルを書いてみました。

Sample.scala (バインド処理のサンプル)
package fits.sample

import scalaz._
import Scalaz._

object Sample extends App {
    import Cont._
    // 継続モナドの作成
    def cont[R](a: Int) = contInstance[R].point(a)

    def calc1[R](x: Int) = cont[R](x + 3)

    def calc2[R](x: Int) = cont[R](x * 10)

    def calc3[R](x: Int) = cont[R](x + 4)

    def calcAll[R](x: Int) = cont[R](x) >>= calc1 >>= calc2 >>= calc3

    calc1(2).runCont { println } // a. 2 + 3 = 5

    calcAll(2).runCont { println } // b. ((2 + 3) * 10) + 4 = 54

    calcAll(2).runCont { x => x - 9 } |> println // c. 54 - 9 = 45
}

a. は calc1(2) で得た継続モナドから runCont で取り出した値 (継続渡し形式の処理) に継続 (println 関数) を渡しており、x(=2) + 3 の結果を引数として println が実行される事になります。

b. は calcAll(2) で得た継続モナドから取り出した値に println を渡しており、以下のような処理が実施されます。

  1. 2 を引数に calc1 実行
  2. calc1 の結果を引数に calc2 実行
  3. calc2 の結果を引数に calc3 実行
  4. calc3 の結果を引数に println 実行

>>= で処理を繋げる事で、処理結果を次の処理に渡していき最終的に継続 (a. や b. における println) を呼び出すような処理を実装できます。

c. は println の代わりに x - 9 という処理を継続として渡しており、これによって継続渡し形式の処理結果が 54 - 9 = 45 となります。

実行結果
> sbt console
・・・
scala> fits.sample.Sample.main(null)
5
54
45

callCC の処理1 (簡易版)

callCC を使った単純なサンプルを実装します。

CallCCSample1.scala (callCC の簡易版サンプル1)
package fits.sample

import scalaz._
import Scalaz._

object CallCCSample1 extends App {
    import Cont._

    def sample[R](n: Int): Cont[R, Int] = callCC { cc: (Int => Cont[R, Int]) =>
        if(n % 2 == 1) {
            cc(n) // (1)
        }
        else {
            contInstance[R].point(n * 10) // (2)
        }
    }

    sample(1).runCont { println } // (1)
    sample(2).runCont { println } // (2)
    sample(3).runCont { println } // (1)
    sample(4).runCont { println } // (2)
}

sample の処理内容は以下のようになっています。

  • 引数 n が奇数なら n の値を継続に適用する継続モナドを返す (1)
  • 引数 n が偶数なら n * 10 の値を継続に適用する継続モナドを返す (2)

実行結果は以下の通りです。
奇数ならそのままの値、偶数なら 10 倍した値が出力されます。

実行結果
scala> fits.sample.CallCCSample1.main(null)
1
20
3
40

callCC の処理1a (Haskell のサンプルに近づけた版)

実は、上記 callCC のサンプルは id:fits:20121104 の Haskell 版と実装方法が結構違っていましたので、when 等も定義して近づけてみました。(かなり分かり難くなってしまいましたが)

CallCCSample1a.scala (callCC のHaskell近似版サンプル1)
package fits.sample

import scalaz._
import Scalaz._

object CallCCSample1a extends App {
    import Cont._

    val odd = (n: Int) => n % 2 == 1

    def when[M[_], A](cond: Boolean)(f: => M[A])(implicit M: Pointed[M]) = 
        if (cond) f else M.point(())

    def sample[R](n: Int): Cont[R, Int] = callCC { cc: (Int => Cont[R, Int]) =>
        for {
            _ <- when (odd(n)) {
                for {
                    _ <- cc(n) // (1)
                } yield ()
            }
        } yield (n * 10) // (2)
    }

    sample(1).runCont { println } // (1)
    sample(2).runCont { println } // (2)
    sample(3).runCont { println } // (1)
    sample(4).runCont { println } // (2)
}

なお、こちらのサンプルをコンパイルする際には -feature オプションを指定します。(指定しないと warning が出ます)

実行結果
> sbt compile -feature
・・・
> sbt console
・・・
scala> fits.sample.CallCCSample1a.main(null)
1
20
3
40

callCC の処理2 (簡易版)

次は、callCC をネストさせたサンプルです。

CallCCSample2.scala (callCC のHaskell簡易版サンプル2)
package fits.sample

import scalaz._
import Scalaz._

object CallCCSample2 extends App {
    import Cont._

    def sample[R](n: Int): Cont[R, Int] = callCC { cc1: (Int => Cont[R, Int]) =>
        if(n % 2 == 1) {
            cc1(n) // (1)
        }
        else {
            for {
                x <- callCC { cc2: (Int => Cont[R, Int]) =>
                    n match {
                        case x if (x < 4) => cc2(n * 1000) // (2)
                        case 4 => cc1(n * 100) // (3)
                        case _ => contInstance[R].point(n * 10) // (4)
                    }
                }
            } yield (x + 1) // (5)
        }
    }

    sample(1).runCont { println } // (1)
    sample(2).runCont { println } // (2) (5)
    sample(3).runCont { println } // (1)
    sample(4).runCont { println } // (3)
    sample(5).runCont { println } // (1)
    sample(6).runCont { println } // (4) (5)
}

sample の処理内容は以下のようになります。

  • 引数 n が奇数なら n の値を継続に適用する継続モナドを返す (1)
  • 引数 n が偶数の場合
    • 4 より小さいと 1000 倍した値に +1 した値を継続に適用する継続モナドを返す (2) (5)
    • 4 なら 100 倍した値を継続に適用する継続モナドを返す (3)
    • それ以外は 10 倍した値に +1 した値を継続に適用する継続モナドを返す (4) (5)

(3) のように 2つ目の callCC 内で cc1 を呼び出すと (5) は実行されず、(2) のように cc2 を呼び出した場合は (5) が適用される事になります。

実行結果
scala> fits.sample.CallCCSample2.main(null)
1
2001
3
400
5
61

callCC の処理2a (Haskell のサンプルに近づけた版)

こちらも when 等を定義して Haskell のサンプルに近づけてみました。

CallCCSample2a.scala (callCC のHaskell近似版サンプル2)
package fits.sample

import scalaz._
import Scalaz._

object CallCCSample2a extends App {
    import Cont._

    val odd = (n: Int) => n % 2 == 1

    def when[M[_], A](cond: Boolean)(f: => M[A])(implicit M: Pointed[M]) = 
        if (cond) f else M.point(())

    def sample[R](n: Int): Cont[R, Int] = callCC { cc1: (Int => Cont[R, Int]) =>
        for {
            _ <- when (odd(n)) {
                for {
                    _ <- cc1(n) // (1)
                } yield ()
            }
            x <- callCC { cc2: (Int => Cont[R, Int]) =>
                for {
                    _ <- when (n < 4) {
                        for {
                            _ <- cc2(n * 1000) // (2)
                        } yield ()
                    }
                    _ <- when (n == 4) {
                        for {
                            _ <- cc1(n * 100) // (3)
                        } yield ()
                    }
                } yield (n * 10) // (4)
            }
        } yield (x + 1) // (5)
    }

    sample(1).runCont { println } // (1)
    sample(2).runCont { println } // (2) (5)
    sample(3).runCont { println } // (1)
    sample(4).runCont { println } // (3)
    sample(5).runCont { println } // (1)
    sample(6).runCont { println } // (4) (5)
}
実行結果
> sbt compile -feature
・・・
> sbt console
・・・
scala> fits.sample.CallCCSample2a.main(null)
1
2001
3
400
5
61