Java のアノテーションプロセッサで Haskell の do 記法のようなものを簡易的に実現

アノテーションプロセッサで AST 変換 - Lombok を参考にして変数の型をコンパイル時に変更」の応用編です。

前回は変数の型を var から java.lang.Object へ変更しただけでしたが、今回は下記と同等な機能の簡易版をアノテーションプロセッサで実現してみます。

  • Haskell の do 記法
  • Scala の for 内包表記
  • F# のコンピュテーション式

ソースは http://github.com/fits/try_samples/tree/master/blog/20150511/

改良版は 「Java のアノテーションプロセッサで Haskell の do 記法のようなものを簡易的に実現3」 を参照

はじめに

前回 と同様にアノテーションプロセッサにおける下記の特徴を利用し、コンパイル前に AST (抽象構文木) を書き換えます。

変換例

今回は、以下のような変換をアノテーションプロセッサ内で実施します。 Haskell の do 記法というよりは F# のコンピュテーション式に近くなっています。

変換前 (アノテーションプロセッサ処理前)
Supplier<Optional<Integer>> res = ($do<Optional, Integer> opt) -> {
    let a = o1;
    let b = o2;
    return a + b;
};
変換後 (アノテーションプロセッサ処理後)
Supplier<Optional<Integer>> res = ()->{
    return opt.bind(o1, new java.util.function.Function<Integer, Optional<Integer>>(){
        @Override
        public Optional<Integer> apply(Integer a) {
            return opt.bind(o2, new java.util.function.Function<Integer, Optional<Integer>>(){
                @Override
                public Optional<Integer> apply(Integer b) {
                    return opt.unit(a + b);
                }
            });
        }
    });
};

変換内容

変換前の構文と変換方法を簡単に説明します。

ラムダの引数部分は変換に必要な情報を渡すために使用し、$do という未定義のクラスは変換対象かどうかをマーキングする目的で使います。

処理変数(変換例の opt)には bind・unit メソッドを持つオブジェクトインスタンスの変数名を指定します。

また、型推論は難しそうだったので、コンテナ型 (モナドの型) と要素の型を $do クラスの型引数で指定するようにしています。

変換前
・・・ = ($do<コンテナ型, 要素型> 処理変数) -> {
    let 変数1 = 式1;
    ・・・
    return 結果式;
};

変換は以下のように実施します。

ラムダの引数部分 ($do の箇所) を全て消去します。 (変換に必要な情報を渡しているだけなので)

(a) ($do<コンテナ型, 要素型> 処理変数) -> {} の変換内容
・・・ = () -> {
    ・・・
}

let の部分は bind メソッドを使った処理へ変換します。

(b) let 変数1 = 式1; の変換内容
return 処理変数.bind(式1, new java.util.function.Function<要素型, コンテナ型<要素型>>(){
    @Override
    public コンテナ型<要素型> apply(要素型 変数1) {
        ・・・
    }
});

return の部分は unit メソッドを使った処理へ変換します。

(c) return 結果式; の変換内容
return 処理変数.unit( 結果式 );

備考

本当は以下のようなシンプルな仕様で実現したかったのですが、変換後のコンパイル時にエラーが発生してしまい、うまく解決できなかったので今回は断念しました。

変換前 (失敗版)
Supplier<Optional<Integer>> res = opt$do -> {
    let a = o1;
    let b = o2;
    return a + b;
};
変換後 (失敗版)
Supplier<Optional<Integer>> res = () -> opt.bind(o1, a -> opt.bind(o2, b -> opt.unit(a + b)));
変換後のコンパイルエラー例 (Java 1.8.0_45)
java.lang.AssertionError: Value of x -1
        at com.sun.tools.javac.util.Assert.error(Assert.java:133)
        at com.sun.tools.javac.util.Assert.check(Assert.java:94)
        at com.sun.tools.javac.util.Bits.incl(Bits.java:200)
        at com.sun.tools.javac.comp.Flow$AbstractAssignAnalyzer.visitLambda(Flow.java:2254)
        at com.sun.tools.javac.tree.JCTree$JCLambda.accept(JCTree.java:1624)

ただし、変換後のソースをファイルへ出力し javac で普通にコンパイルすれば問題なく成功しますので、実現不可能では無いと思います。

と書きましたが、生成した箇所の pos の値を全て調整し直せば、 上記の構文で実現できることが判明しました。 (Java のアノテーションプロセッサで Haskell の do 記法のようなものを簡易的に実現2 参照)

アノテーションプロセッサの実装

それでは本題に入ります。

Processor の実装

まずは、アノテーションプロセッサの本体を実装します。

こちらは 前回SampleProcessor2.java とほぼ同じ内容です。

変換後のソースを確認するため accept 後の CompilationUnitTreeprintln するようにしています。

src/main/java/sample/DoExprProcessor.java
package sample;

import java.util.Set;
import javax.annotation.processing.*;

import javax.lang.model.SourceVersion;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;

import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.util.Trees;
import com.sun.source.util.TreePath;

import com.sun.tools.javac.processing.JavacProcessingEnvironment;
import com.sun.tools.javac.util.Context;

@SupportedSourceVersion(SourceVersion.RELEASE_8)
@SupportedAnnotationTypes("*")
public class DoExprProcessor extends AbstractProcessor {
    private Trees trees;
    private Context context;

    @Override
    public void init(ProcessingEnvironment procEnv) {
        trees = Trees.instance(procEnv);
        context = ((JavacProcessingEnvironment)procEnv).getContext();
    }

    @Override
    public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
        roundEnv.getRootElements().stream().map(this::toUnit).forEach(this::processUnit);
        return false;
    }

    private CompilationUnitTree toUnit(Element el) {
        TreePath path = trees.getPath(el);
        return path.getCompilationUnit();
    }

    private void processUnit(CompilationUnitTree cu) {
        // AST 変換
        cu.accept(new DoExprVisitor(context), null);
        // 変換後のソースを出力
        System.out.println(cu);
    }
}

TreeVisitor の実装

次に、変換処理を実装します。

ラムダを変更するため TreeScanner を extends し visitLambdaExpression メソッドをオーバーライドします。

LambdaExpressionTree だとラムダの内容を変更できないので JCLambda へキャストし、以下のような処理を実施します。

  • (1) 変換対象かどうかをチェック (引数の型に $do を使っているかどうか等)
  • (2) ラムダの引数を消去
  • (3) ラムダの処理内容(let や return)を元に変換後のソースを生成
  • (4) (3) を JavacParser を使って JCStatement へ変換
  • (5) ラムダの処理内容を (4) の結果で差し替え (JCLambdabody.stats の値を変更)

通常は、JCStatement を直接構築して差し替えると思うのですが、JCStatement を自前で構築するのは大変そうだったので、部分的なソースコードを作り (3) 、それを JavacParser にパースさせる事で JCStatement を得ています (4)。

また、JCTree 内で使用されている List クラスは java.util.List ではなく com.sun.tools.javac.util.List なので注意。

src/main/java/sample/DoExprVisitor.java
package sample;

import com.sun.source.tree.LambdaExpressionTree;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.parser.ParserFactory;
import com.sun.tools.javac.tree.JCTree.*;
import com.sun.tools.javac.util.Context;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Stream;

public class DoExprVisitor extends TreeScanner<Void, Void> {
    private static final String DO_TYPE = "$do";
    // let 変換用のソースコード
    private static final String BIND_CODE = " return ${var}.bind(${rExpr}, new java.util.function.Function<${vType}, ${mType}<${vType}>>(){" +
            "  @Override public ${mType}<${vType}> apply(${vType} ${lExpr}){ ${body} }" +
            " });";
    // return 変換用のソースコード
    private static final String UNIT_CODE = "  return ${var}.unit( ${expr} );";

    private ParserFactory parserFactory;
    private Map<String, TemplateBuilder> builderMap = new HashMap<>();

    public DoExprVisitor(Context context) {
        parserFactory = ParserFactory.instance(context);
        // (b) let 用の変換内容
        builderMap.put("let", new TemplateBuilder(BIND_CODE, this::createBindParams));
        // (c) return 用の変換内容
        builderMap.put("return", new TemplateBuilder(UNIT_CODE, this::createUnitParams));
    }

    @Override
    public Void visitLambdaExpression(LambdaExpressionTree node, Void p) {
        if (node instanceof JCLambda) {
            JCLambda lm = (JCLambda)node;

            // (1) 変換対象かどうかをチェック
            if (isSingleTypeApplyParam(lm)) {
                JCVariableDecl param = lm.params.get(0);

                // (1) 変換対象かどうかをチェック
                if (isDoType(param)) {
                    // (2) ラムダの引数を消去
                    lm.params = com.sun.tools.javac.util.List.nil();

                    JCBlock block = (JCBlock)lm.body;
                    // 変換後の処理内容(Statement)を作成 (3) (4)
                    JCStatement newStats = parseStatement(createStatement(block, createBaseParams(param)));
                    // (5) ラムダの処理内容を差し替え
                    block.stats = com.sun.tools.javac.util.List.of(newStats);
                }
            }
        }
        return super.visitLambdaExpression(node, p);
    }
    // (3) ラムダの処理内容を変換したソースコードを生成
    private String createStatement(JCBlock block, Map<String, String> params) {
        // ラムダの内容 (JCStatement のリスト) を逆順化して個別に文字列化
        Stream<String> revStats = block.stats.reverse().stream().map(s -> s.toString().replaceAll(";", ""));

        // 逆順化したリストを順次変換
        return revStats.reduce("", (acc, v) -> {
            int spacePos = v.indexOf(" ");
            String action = v.substring(0, spacePos);

            if (builderMap.containsKey(action)) {
                acc = builderMap.get(action).build(params, acc, v.substring(spacePos + 1));
            }

            return acc;
        });
    }
    // (4) 生成したソースコードを JavacParser で JCStatement へ変換
    private JCStatement parseStatement(String doStat) {
        return parserFactory.newParser(doStat, false, false, false).parseStatement();
    }
    // (1)
    private boolean isDoType(JCVariableDecl param) {
        String type = ((JCTypeApply)param.vartype).clazz.toString();
        return DO_TYPE.equals(type);
    }
    // (1)
    private boolean isSingleTypeApplyParam(JCLambda lm) {
        return lm.params.size() == 1
                && lm.params.get(0).vartype instanceof JCTypeApply;
    }

    private Map<String, String> createBaseParams(JCVariableDecl param) {
        Map<String, String> params = new HashMap<>();

        params.put("var", param.name.toString());

        JCTypeApply paramType = (JCTypeApply)param.vartype;
        params.put("mType", paramType.arguments.get(0).toString());
        params.put("vType", paramType.arguments.get(1).toString());

        return params;
    }

    private Map<String, String> createBindParams(String body, String expr) {
        Map<String, String> params = createUnitParams(body, expr);

        String[] divexp = expr.split("=");
        params.put("lExpr", divexp[0]);
        params.put("rExpr", divexp[1]);

        return params;
    }

    private Map<String, String> createUnitParams(String body, String expr) {
        Map<String, String> params = new HashMap<>();

        params.put("body", body);
        params.put("expr", expr);

        return params;
    }
    // テンプレート処理を実施するためのクラス
    private class TemplateBuilder {
        private static final String VAR_PREFIX = "\\$\\{";
        private static final String VAR_SUFFIX = "\\}";

        private String template;
        private BiFunction<String, String, Map<String, String>> paramCreator;

        TemplateBuilder(String template, BiFunction<String, String, Map<String, String>> paramCreator) {
            this.template = template;
            this.paramCreator = paramCreator;
        }

        public String build(Map<String, String> params, String body, String expr) {
            return buildTemplate(
                    buildTemplate(template, params),
                    paramCreator.apply(body, expr));
        }

        private String buildTemplate(String template, Map<String, String> params) {
            return params.entrySet().stream().reduce(template,
                    (acc, v) -> acc.replaceAll(VAR_PREFIX + v.getKey() + VAR_SUFFIX, v.getValue()),
                    (a, b) -> a);
        }
    }
}

JCLambda の変換イメージ例

JCLambda の大まかな変換イメージを書くと以下のようになります。

ソースコード
($do<Optional, String> opt2) -> {
    let a = Optional.of("a");
    let b = Optional.of("b");
    return a + b;
}

上記のソースを JCLambda 化すると以下のようになります。 実際はもっと複雑ですが、適当に簡略化しています。 (init や expr の内容も実際は JCMethodInvocation 等が入れ子になっています)

変換前 JCLambda の内容
JCLambda(
    params = [
        JCVariableDecl(
            name = opt2,
            vartype = $do<Optional, String>
        )
    ],
    paramKind = EXPLICIT,
    body = JCBlock(
        stats = [
            JCVariableDecl(
                name = a,
                vartype = let,
                init = Optional.of("a")
            ),
            JCVariableDecl(
                name = b,
                vartype = let,
                init = Optional.of("a")
            ),
            JCReturn(
                expr = a + b
            )
        ]
    )
)

DoExprVisitor では上記を以下のように変更します。

変換後 JCLambda の内容
JCLambda(
    params = [],
    paramKind = EXPLICIT,
    body = JCBlock(
        stats = [
            JCReturn(
                expr = opt2.bind(Optional.of("a"), ・・・)
            )
        ]
    )
)

今回は body に設定されている JCBlock をそのまま使いましたが、body の値を直接変更する方法も考えられます。

Service Provider 設定ファイル

アノテーションプロセッサで sample.DoExprProcessor を使用するように Service Provider の設定ファイルを用意します。

src/main/resources/META-INF/services/javax.annotation.processing.Processor
sample.DoExprProcessor

ビルド

ビルド定義ファイルは以下の通り。

build.gradle
apply plugin: 'java'

def enc = 'UTF-8'
tasks.withType(AbstractCompile)*.options*.encoding = enc

dependencies {
    compile files("${System.properties['java.home']}/../lib/tools.jar")
}
ビルド実行
> gradle build

:compileJava
:processResources UP-TO-DATE
:classes
:jar
:assemble
:compileTestJava UP-TO-DATE
:processTestResources UP-TO-DATE
:testClasses UP-TO-DATE
:test UP-TO-DATE
:check UP-TO-DATE
:build

BUILD SUCCESSFUL

ビルド結果として build/libs/java_do_expr.jar が生成されました。

動作確認

最後に、下記のサンプルコードを使ってアノテーションプロセッサの動作確認を行います。

example/DoExprSample.java
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.Optional;

public class DoExprSample {
    public static void main(String... args) {
        Optional<Integer> o1 = Optional.of(2);
        Optional<Integer> o2 = Optional.of(3);

        Opt<Integer> opt = new Opt<>();
        // アノテーションプロセッサで変換する処理1
        Supplier<Optional<Integer>> res = ($do<Optional, Integer> opt) -> {
            let a = o1;
            let b = o2;
            let c = Optional.of(4);
            return a + b + c * 2;
        };

        // Optional[13]
        System.out.println(res.get());

        Opt<String> opt2 = new Opt<>();
        // アノテーションプロセッサで変換する処理2
        Supplier<Optional<String>> res2 = ($do<Optional, String> opt2) -> {
            let a = Optional.of("a");
            let b = Optional.of("b");
            return a + b;
        };

        // Optional["ab"]
        System.out.println(res2.get());
    }
    // Optional 用の bind・unit メソッド実装クラス
    static class Opt<T> {
        public Optional<T> bind(Optional<T> x, Function<T, Optional<T>> f) {
            return x.flatMap(f);
        }

        public Optional<T> unit(T v) {
            return Optional.ofNullable(v);
        }
    }
}

java_do_expr.jar を使って上記ソースファイルをコンパイルします。

出力内容(変換後のソースコード)を見る限り正常に変換できているようです。

コンパイル
> javac -cp ../build/libs/java_do_expr.jar DoExprSample.java

・・・
public class DoExprSample {
    ・・・
    public static void main(String... args) {
        Optional<Integer> o1 = Optional.of(2);
        Optional<Integer> o2 = Optional.of(3);
        Opt<Integer> opt = new Opt<>();
        Supplier<Optional<Integer>> res = ()->{
            return opt.bind(o1, new java.util.function.Function<Integer, Optional<Integer>>(){

                @Override()
                public Optional<Integer> apply(Integer a) {
                    return opt.bind(o2, new java.util.function.Function<Integer, Optional<Integer>>(){

                        @Override()
                        public Optional<Integer> apply(Integer b) {
                            return opt.bind(Optional.of(4), new java.util.function.Function<Integer, Optional<Integer>>(){

                                @Override()
                                public Optional<Integer> apply(Integer c) {
                                    return opt.unit(a + b + c * 2);
                                }
                            });
                        }
                    });
                }
            });
        };
        System.out.println(res.get());
        Opt<String> opt2 = new Opt<>();
        Supplier<Optional<String>> res2 = ()->{
            return opt2.bind(Optional.of("a"), new java.util.function.Function<String, Optional<String>>(){

                @Override()
                public Optional<String> apply(String a) {
                    return opt2.bind(Optional.of("b"), new java.util.function.Function<String, Optional<String>>(){

                        @Override()
                        public Optional<String> apply(String b) {
                            return opt2.unit(a + b);
                        }
                    });
                }
            });
        };
        System.out.println(res2.get());
    }
    ・・・
}

DoExprSample を実行すると正常に動作しました。

実行結果
> java DoExprSample

Optional[13]
Optional[ab]