Scalaz で継続モナド
以前(id:fits:20121104)、Haskell で実装した継続モナドのサンプルを Scalaz で実装してみました。
なお、今のところ Scalaz に継続モナドは用意されていないようで、id:fits:20121111 のような方法で自作する必要がありました。
ただし、実際のところ Scala で継続モナドが必要となるケースは基本的に無さそうな気がします。
サンプルソースは http://github.com/fits/try_samples/tree/master/blog/20121125/
今回使用した sbt 用のビルドファイルは以下です。
build.sbt
scalaVersion := "2.10.0-RC2" libraryDependencies += "org.scalaz" % "scalaz-core" % "7.0.0-M4" cross CrossVersion.full
継続モナドの実装
はじめに、継続モナドを実装します。
Haskell の実装を参考に Cont モナドと callCC を定義してみました。
Cont.scala (継続モナドの定義)
package fits.sample import scalaz._ import Scalaz._ // モナドとして扱う型の定義 case class Cont[R, A](runCont: (A => R) => R) // 継続モナドの注入関数・連鎖関数を実装 trait ContMonad[R] extends Monad[({type r[a] = Cont[R, a]})#r] { // 注入関数の実装 def point[A](a: => A) = Cont { k => k(a) } // 連鎖関数の実装 def bind[A, B](fa: Cont[R, A])(f: (A) => Cont[R, B]) = { Cont { k => fa.runCont { a => f(a).runCont(k) } } } } trait ContFunctions { // callCC の実装 def callCC[R, A, B](f: (A => Cont[R, B]) => Cont[R, A]): Cont[R, A] = { Cont { k => f { a => Cont { _ => k(a) } }.runCont(k) } } } trait ContInstances { // 継続モナドのインスタンス定義 implicit def contInstance[R] = new ContMonad[R] { } } case object Cont extends ContFunctions with ContInstances
バインド >>= の処理
id:fits:20121104 と同様にバインド関数 >>= を使った簡単なサンプルを書いてみました。
Sample.scala (バインド処理のサンプル)
package fits.sample import scalaz._ import Scalaz._ object Sample extends App { import Cont._ // 継続モナドの作成 def cont[R](a: Int) = contInstance[R].point(a) def calc1[R](x: Int) = cont[R](x + 3) def calc2[R](x: Int) = cont[R](x * 10) def calc3[R](x: Int) = cont[R](x + 4) def calcAll[R](x: Int) = cont[R](x) >>= calc1 >>= calc2 >>= calc3 calc1(2).runCont { println } // a. 2 + 3 = 5 calcAll(2).runCont { println } // b. ((2 + 3) * 10) + 4 = 54 calcAll(2).runCont { x => x - 9 } |> println // c. 54 - 9 = 45 }
a. は calc1(2) で得た継続モナドから runCont で取り出した値 (継続渡し形式の処理) に継続 (println 関数) を渡しており、x(=2) + 3 の結果を引数として println が実行される事になります。
b. は calcAll(2) で得た継続モナドから取り出した値に println を渡しており、以下のような処理が実施されます。
- 2 を引数に calc1 実行
- calc1 の結果を引数に calc2 実行
- calc2 の結果を引数に calc3 実行
- calc3 の結果を引数に println 実行
>>= で処理を繋げる事で、処理結果を次の処理に渡していき最終的に継続 (a. や b. における println) を呼び出すような処理を実装できます。
c. は println の代わりに x - 9 という処理を継続として渡しており、これによって継続渡し形式の処理結果が 54 - 9 = 45 となります。
実行結果
> sbt console ・・・ scala> fits.sample.Sample.main(null) 5 54 45
callCC の処理1 (簡易版)
callCC を使った単純なサンプルを実装します。
CallCCSample1.scala (callCC の簡易版サンプル1)
package fits.sample import scalaz._ import Scalaz._ object CallCCSample1 extends App { import Cont._ def sample[R](n: Int): Cont[R, Int] = callCC { cc: (Int => Cont[R, Int]) => if(n % 2 == 1) { cc(n) // (1) } else { contInstance[R].point(n * 10) // (2) } } sample(1).runCont { println } // (1) sample(2).runCont { println } // (2) sample(3).runCont { println } // (1) sample(4).runCont { println } // (2) }
sample の処理内容は以下のようになっています。
実行結果は以下の通りです。
奇数ならそのままの値、偶数なら 10 倍した値が出力されます。
実行結果
scala> fits.sample.CallCCSample1.main(null) 1 20 3 40
callCC の処理1a (Haskell のサンプルに近づけた版)
実は、上記 callCC のサンプルは id:fits:20121104 の Haskell 版と実装方法が結構違っていましたので、when 等も定義して近づけてみました。(かなり分かり難くなってしまいましたが)
CallCCSample1a.scala (callCC のHaskell近似版サンプル1)
package fits.sample import scalaz._ import Scalaz._ object CallCCSample1a extends App { import Cont._ val odd = (n: Int) => n % 2 == 1 def when[M[_], A](cond: Boolean)(f: => M[A])(implicit M: Pointed[M]) = if (cond) f else M.point(()) def sample[R](n: Int): Cont[R, Int] = callCC { cc: (Int => Cont[R, Int]) => for { _ <- when (odd(n)) { for { _ <- cc(n) // (1) } yield () } } yield (n * 10) // (2) } sample(1).runCont { println } // (1) sample(2).runCont { println } // (2) sample(3).runCont { println } // (1) sample(4).runCont { println } // (2) }
なお、こちらのサンプルをコンパイルする際には -feature オプションを指定します。(指定しないと warning が出ます)
実行結果
> sbt compile -feature ・・・ > sbt console ・・・ scala> fits.sample.CallCCSample1a.main(null) 1 20 3 40
callCC の処理2 (簡易版)
次は、callCC をネストさせたサンプルです。
CallCCSample2.scala (callCC のHaskell簡易版サンプル2)
package fits.sample import scalaz._ import Scalaz._ object CallCCSample2 extends App { import Cont._ def sample[R](n: Int): Cont[R, Int] = callCC { cc1: (Int => Cont[R, Int]) => if(n % 2 == 1) { cc1(n) // (1) } else { for { x <- callCC { cc2: (Int => Cont[R, Int]) => n match { case x if (x < 4) => cc2(n * 1000) // (2) case 4 => cc1(n * 100) // (3) case _ => contInstance[R].point(n * 10) // (4) } } } yield (x + 1) // (5) } } sample(1).runCont { println } // (1) sample(2).runCont { println } // (2) (5) sample(3).runCont { println } // (1) sample(4).runCont { println } // (3) sample(5).runCont { println } // (1) sample(6).runCont { println } // (4) (5) }
sample の処理内容は以下のようになります。
- 引数 n が奇数なら n の値を継続に適用する継続モナドを返す (1)
- 引数 n が偶数の場合
(3) のように 2つ目の callCC 内で cc1 を呼び出すと (5) は実行されず、(2) のように cc2 を呼び出した場合は (5) が適用される事になります。
実行結果
scala> fits.sample.CallCCSample2.main(null) 1 2001 3 400 5 61
callCC の処理2a (Haskell のサンプルに近づけた版)
こちらも when 等を定義して Haskell のサンプルに近づけてみました。
CallCCSample2a.scala (callCC のHaskell近似版サンプル2)
package fits.sample import scalaz._ import Scalaz._ object CallCCSample2a extends App { import Cont._ val odd = (n: Int) => n % 2 == 1 def when[M[_], A](cond: Boolean)(f: => M[A])(implicit M: Pointed[M]) = if (cond) f else M.point(()) def sample[R](n: Int): Cont[R, Int] = callCC { cc1: (Int => Cont[R, Int]) => for { _ <- when (odd(n)) { for { _ <- cc1(n) // (1) } yield () } x <- callCC { cc2: (Int => Cont[R, Int]) => for { _ <- when (n < 4) { for { _ <- cc2(n * 1000) // (2) } yield () } _ <- when (n == 4) { for { _ <- cc1(n * 100) // (3) } yield () } } yield (n * 10) // (4) } } yield (x + 1) // (5) } sample(1).runCont { println } // (1) sample(2).runCont { println } // (2) (5) sample(3).runCont { println } // (1) sample(4).runCont { println } // (3) sample(5).runCont { println } // (1) sample(6).runCont { println } // (4) (5) }
実行結果
> sbt compile -feature ・・・ > sbt console ・・・ scala> fits.sample.CallCCSample2a.main(null) 1 2001 3 400 5 61