Haskell で継続モナド

継続渡し形式 (CPS) をモナドとして扱う継続モナドHaskell で試してみました。
継続モナドは以下のような処理をモナド化します。

  • 何らかの処理結果を引数として継続と呼ばれる関数を呼び出す(継続は外部から与える)

処理結果を引数にコールバック関数が呼ばれるようなイメージで捉えた方が分かり易いかもしれません。

サンプルソースは http://github.com/fits/try_samples/tree/master/blog/20121104/

>>= の処理

とりあえずバインド関数 >>= を使った簡単なサンプルを書いてみました。

sample.hs
import Control.Monad.Cont

calc1 :: Int -> Cont r Int
calc1 x = return (x + 3)

calc2 :: Int -> Cont r Int
calc2 x = return (x * 10)

calc3 :: Int -> Cont r Int
calc3 x = return (x + 4)

calcAll :: Int -> Cont r Int
calcAll x = return x >>= calc1 >>= calc2 >>= calc3

main = do
    -- a. 2 + 3 = 5
    runCont (calc1 2) print

    -- b. ((2 + 3) * 10) + 4 = 54
    runCont (calcAll 2) print
    -- 上記は以下と同じ
    -- runCont (calcAll 2) (\x -> print x)

    -- c. (((2 + 3) * 10) + 4) - 9 = 45
    print $ runCont (calcAll 2) (\x -> x - 9)

a. は calc1 2 で得た継続モナドから runCont で取り出した値 (継続渡し形式の処理) に継続 (print 関数) を渡しており、x(=2) + 3 の結果を引数として print が実行される事になります。

b. は calcAll 2 で得た継続モナドから取り出した継続渡し形式の処理に print を渡しており、以下のような処理が実施される事になります。

  1. 2 を引数に calc1 実行
  2. calc1 の結果を引数に calc2 実行
  3. calc2 の結果を引数に calc3 実行
  4. calc3 の結果を引数に print 実行

>>= で処理を繋げる事で、処理結果を次の処理に渡していき最終的に継続 (a. や b. における print) を呼び出すような処理を実装できます。

c. は print の代わりに \x -> x - 9 というラムダ式を継続として渡しており、これによって継続渡し形式の処理結果が 54 - 9 = 45 となります。

実行結果
> runghc sample.hs
5
54
45

callCC の処理1

次に callCC 関数を使った簡単なサンプルを書いてみました。

callcc_sample1.hs
import Control.Monad.Cont

sample :: Int -> Cont r Int
sample n = callCC $ \cc -> do
    when (odd n) $ do
        -- (1)
        cc n

    -- (2)
    return (n * 10)

main = do
    runCont (sample 1) print -- (1)
    runCont (sample 2) print -- (2)
    runCont (sample 3) print -- (1)
    runCont (sample 4) print -- (2)

callCC は ((a -> Cont r b) -> Cont r a) を引数にとって Cont r a を返す関数で、上記の callCC に渡しているラムダ式の cc が (a -> Cont r b) に該当します。

cc が呼び出されると callCC 内の残りの処理がスキップされ、cc の引数に渡された値 (上記の n) を継続に適用する継続モナドが返ります。

つまり、sample 関数の処理内容は以下のようになります。

  • 引数 n が奇数なら n の値を継続に適用する継続モナドを返す (1)
  • 引数 n が偶数なら n * 10 の値を継続に適用する継続モナドを返す (2)

実行結果は以下の通り。
奇数ならそのままの値、偶数なら 10 倍した値が出力される事を確認できます。

実行結果
> runghc callcc_sample1.hs
1
20
3
40

callCC の処理2

最後に callCC をネストさせたサンプルを書いてみました。

callCC をネストさせる事で、ある程度複雑な制御構造を実現できそうですが、コードが分かり難くなる点に注意が必要だと思います。

callcc_sample2.hs
import Control.Monad.Cont

sample :: Int -> Cont r Int
sample n = callCC $ \cc1 -> do
    when (odd n) $ do
        -- (1)
        cc1 n

    x <- callCC $ \cc2 -> do
        when (n < 4) $ do
            -- (2)
            cc2 (n * 1000)

        when (n == 4) $ do
            -- (3)
            cc1 (n * 100)

        -- (4)
        return (n * 10)

    -- (5)
    return (x + 1)

main = do
    runCont (sample 1) print -- (1)
    runCont (sample 2) print -- (2) (5)
    runCont (sample 3) print -- (1)
    runCont (sample 4) print -- (3)
    runCont (sample 5) print -- (1)
    runCont (sample 6) print -- (4) (5)

この sample 関数の処理内容は以下のようになります。

  • 引数 n が奇数なら n の値を継続に適用する継続モナドを返す (1)
  • 引数 n が偶数の場合
    • 4 より小さいと 1000 倍した値に +1 した値を継続に適用する継続モナドを返す (2) (5)
    • 4 なら 100 倍した値を継続に適用する継続モナドを返す (3)
    • それ以外は 10 倍した値に +1 した値を継続に適用する継続モナドを返す (4) (5)

(3) のように 2つ目の callCC 内で cc1 を呼び出すと残処理は全てスキップされる事になりますが、(2) のように cc2 を呼び出した場合は 2つ目の callCC の残処理がスキップされるだけでその後の処理 (5) が適用される事になります。

実行結果
> runghc callcc_sample2.hs
1
2001
3
400
5
61