読者です 読者をやめる 読者になる 読者になる

Java のアノテーションプロセッサで Haskell の do 記法のようなものを簡易的に実現2

Java Monad

前回 に引き続き、今回も Javaアノテーションプロセッサを使って下記と同等機能を実現します。

  • Haskell の do 記法
  • Scala の for 内包表記
  • F# のコンピュテーション式

今回は、F# のコンピュテーション式を模した下記のような構文 (前回断念したもの) を使用します。

Supplier<Optional<Integer>> res = opt$do -> {
    let a = o1;
    let b = o2;
    return a + b;
};

ソースは http://github.com/fits/try_samples/tree/master/blog/20150513/

改良版は 「Java のアノテーションプロセッサで Haskell の do 記法のようなものを簡易的に実現3」 を参照

はじめに

基本的な変換方法は 前回 と同じですが、かなりシンプルになっていると思います。

変数名$do の $do は変換対象としてマーキングするために付けています。

変換前 (アノテーションプロセッサ処理前)
Supplier<Optional<Integer>> res = opt$do -> {
    let a = o1;
    let b = o2;
    return a + b;
};
変換後 (アノテーションプロセッサ処理後)
Supplier<Optional<Integer>> res = () -> opt.bind(o1, a -> opt.bind(o2, b -> opt.unit(a + b)));

アノテーションプロセッサの実装

Processor の実装

アノテーションプロセッサの本体は 前回 と同じです。

src/main/java/sample/DoExprProcessor.java
package sample;

import java.util.Set;
import javax.annotation.processing.*;

import javax.lang.model.SourceVersion;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;

import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.util.Trees;
import com.sun.source.util.TreePath;

import com.sun.tools.javac.processing.JavacProcessingEnvironment;
import com.sun.tools.javac.util.Context;

@SupportedSourceVersion(SourceVersion.RELEASE_8)
@SupportedAnnotationTypes("*")
public class DoExprProcessor extends AbstractProcessor {
    private Trees trees;
    private Context context;

    @Override
    public void init(ProcessingEnvironment procEnv) {
        trees = Trees.instance(procEnv);
        context = ((JavacProcessingEnvironment)procEnv).getContext();
    }

    @Override
    public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
        roundEnv.getRootElements().stream().map(this::toUnit).forEach(this::processUnit);
        return false;
    }

    private CompilationUnitTree toUnit(Element el) {
        TreePath path = trees.getPath(el);
        return path.getCompilationUnit();
    }

    private void processUnit(CompilationUnitTree cu) {
        // AST 変換
        cu.accept(new DoExprVisitor(context), null);
        // 変換後のソースを出力
        System.out.println(cu);
    }
}

TreeVisitor の実装

基本的な変換内容は 前回 と同じですが、下記の点が異なります。

  • (1) 対象処理を変換したソースコードを作って JCExpression へパース
  • (2) 生成した JCExpression 内の全 pos の値を修正
  • (3) JCLambda の body を差し替え

(2) が重要で、posソースコード内の位置) の値を調整しておかないと変換後の AST をコンパイルする段階でエラーになります。 (前回失敗した理由)

新しく生成した JCExpression木構造をたどって全要素の pos を変更するために com.sun.tools.javac.tree.TreeScannerscan メソッドをオーバーライドして使っています。

また、今回の構文ではラムダの paramKindIMPLICIT となりますので(前回はラムダ引数の型を指定していたので EXPLICIT だった)、ラムダの引数を消去した際に paramKindEXPLICIT へ変更しています。

src/main/java/sample/DoExprVisitor.java
package sample;

import com.sun.source.tree.*;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.parser.ParserFactory;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.*;
import com.sun.tools.javac.util.Context;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

public class DoExprVisitor extends TreeScanner<Void, Void> {
    private static final String DO_TYPE = "$do";

    private ParserFactory parserFactory;
    private Map<String, TemplateBuilder> builderMap = new HashMap<>();

    public DoExprVisitor(Context context) {
        parserFactory = ParserFactory.instance(context);
        // let 用の変換内容
        builderMap.put("let",
                new TemplateBuilder("${var}.bind(${rExpr}, ${lExpr} -> ${body})", this::createBindParams));
        // return 用の変換内容
        builderMap.put("return",
                new TemplateBuilder("${var}.unit( ${expr} )", this::createBasicParams));
    }

    @Override
    public Void visitLambdaExpression(LambdaExpressionTree node, Void p) {
        if (node instanceof JCLambda) {
            JCLambda lm = (JCLambda)node;

            if (lm.params.size() == 1) {
                getDoVar(lm.params.get(0)).ifPresent(var -> {
                    // ラムダの引数を消去
                    lm.params = com.sun.tools.javac.util.List.nil();
                    lm.paramKind = JCLambda.ParameterKind.EXPLICIT;

                    // (1) 対象処理を変換したソースコードを作って JCExpression へパース
                    JCExpression ne = parseExpression(createExpression((JCBlock)lm.body, var));
                    // (2) 生成した JCExpression 内の全 pos の値を修正
                    fixPos(ne, lm.pos);
                    // (3) JCLambda の body を差し替え
                    lm.body = ne;
                });
            }
        }
        return super.visitLambdaExpression(node, p);
    }

    // pos の値を修正する
    private void fixPos(JCExpression ne, int basePos) {
        ne.accept(new com.sun.tools.javac.tree.TreeScanner() {
            @Override
            public void scan(JCTree tree) {
                if(tree != null) {
                    tree.pos += basePos;
                    super.scan(tree);
                }
            }
        });
    }

    // 対象処理を変換したソースコード (Expression) を生成
    private String createExpression(JCBlock block, String var) {
        Stream<String> revExpr = block.stats.reverse().stream().map(s -> s.toString().replaceAll(";", ""));

        return revExpr.reduce("", (acc, v) -> {
            int spacePos = v.indexOf(" ");
            String action = v.substring(0, spacePos);

            if (builderMap.containsKey(action)) {
                acc = builderMap.get(action).build(var, acc, v.substring(spacePos + 1));
            }

            return acc;
        });
    }

    // 生成したソースコード (Expression) を JavacParser で JCExpression へ変換
    private JCExpression parseExpression(String doExpr) {
        return parserFactory.newParser(doExpr, false, false, false).parseExpression();
    }

    private Optional<String> getDoVar(JCVariableDecl param) {
        String name = param.name.toString();

        return name.endsWith(DO_TYPE)? Optional.of(name.replace(DO_TYPE, "")): Optional.empty();
    }

    private Map<String, String> createBindParams(String var, String body, String expr) {
        Map<String, String> params = createBasicParams(var, body, expr);

        String[] vexp = expr.split("=");
        params.put("lExpr", vexp[0]);
        params.put("rExpr", vexp[1]);

        return params;
    }

    private Map<String, String> createBasicParams(String var, String body, String expr) {
        Map<String, String> params = new HashMap<>();

        params.put("var", var);
        params.put("body", body);
        params.put("expr", expr);

        return params;
    }

    private interface ParamCreator {
        Map<String, String> create(String var, String body, String expr);
    }

    private class TemplateBuilder {
        private static final String VAR_PREFIX = "\\$\\{";
        private static final String VAR_SUFFIX = "\\}";

        private String template;
        private ParamCreator paramCreator;

        TemplateBuilder(String template, ParamCreator paramCreator) {
            this.template = template;
            this.paramCreator = paramCreator;
        }

        public String build(String var, String body, String expr) {
            return buildTemplate(template, paramCreator.create(var, body, expr));
        }

        private String buildTemplate(String template, Map<String, String> params) {
            return params.entrySet().stream().reduce(template,
                    (acc, v) -> acc.replaceAll(VAR_PREFIX + v.getKey() + VAR_SUFFIX, v.getValue()),
                    (a, b) -> a);
        }
    }
}

Service Provider 設定ファイルやビルド定義も 前回 と同じものです。

Service Provider 設定ファイル

src/main/resources/META-INF/services/javax.annotation.processing.Processor
sample.DoExprProcessor
build.gradle
apply plugin: 'java'

def enc = 'UTF-8'
tasks.withType(AbstractCompile)*.options*.encoding = enc

dependencies {
    compile files("${System.properties['java.home']}/../lib/tools.jar")
}

ビルド

ビルド実行
> gradle build

:compileJava
:processResources UP-TO-DATE
:classes
:jar
:assemble
:compileTestJava UP-TO-DATE
:processTestResources UP-TO-DATE
:testClasses UP-TO-DATE
:test UP-TO-DATE
:check UP-TO-DATE
:build

BUILD SUCCESSFUL

ビルド結果として build/libs/java_do_expr.jar2 が生成されました。

動作確認

下記のサンプルコードを使ってアノテーションプロセッサの動作確認を行います。

example/DoExprSample.java
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.Optional;

public class DoExprSample {
    public static void main(String... args) {
        Optional<Integer> o1 = Optional.of(2);
        Optional<Integer> o2 = Optional.of(3);

        Opt<Integer> opt = new Opt<>();
        // アノテーションプロセッサで変換する処理1
        Supplier<Optional<Integer>> res = opt$do -> {
            let a = o1;
            let b = o2;
            let c = Optional.of(4);
            return a + b + c * 2;
        };

        // Optional[13]
        System.out.println(res.get());

        Opt<String> opt2 = new Opt<>();
        // アノテーションプロセッサで変換する処理2
        Supplier<Optional<String>> res2 = opt2$do -> {
            let a = Optional.of("a");
            let b = Optional.of("b");
            return a + b;
        };

        // Optional["ab"]
        System.out.println(res2.get());
    }
    // Optional 用の bind・unit メソッド実装クラス
    static class Opt<T> {
        public Optional<T> bind(Optional<T> x, Function<T, Optional<T>> f) {
            return x.flatMap(f);
        }

        public Optional<T> unit(T v) {
            return Optional.ofNullable(v);
        }
    }
}

java_do_expr2.jar を使って上記ソースファイルをコンパイルします。

出力内容(変換後のソースコード)を見る限り正常に変換できているようです。

コンパイル
> javac -cp ../build/libs/java_do_expr2.jar DoExprSample.java

・・・
public class DoExprSample {
    ・・・
    public static void main(String... args) {
        Optional<Integer> o1 = Optional.of(2);
        Optional<Integer> o2 = Optional.of(3);
        Opt<Integer> opt = new Opt<>();
        Supplier<Optional<Integer>> res = ()->opt.bind(o1, (a)->opt.bind(o2, (b)->opt.bind(Optional.of(4), (c)->opt.unit(a + b + c * 2))));
        System.out.println(res.get());
        Opt<String> opt2 = new Opt<>();
        Supplier<Optional<String>> res2 = ()->opt2.bind(Optional.of("a"), (a)->opt2.bind(Optional.of("b"), (b)->opt2.unit(a + b)));
        System.out.println(res2.get());
    }
    ・・・
}

DoExprSample を実行すると正常に動作しました。

実行結果
> java DoExprSample

Optional[13]
Optional[ab]