Java のアノテーションプロセッサで Haskell の do 記法のようなものを簡易的に実現3

Javaアノテーションプロセッサを使って下記と同等の機能を実現する試みの第三弾です。

  • Haskell の do 記法
  • Scala の for 内包表記
  • F# のコンピュテーション式

前回 のものを改良し、ようやく下記のような構文を実現しました。

Optional<String> res = opt$do -> {
    let a = Optional.of("a");
    let b = Optional.of("b");
    let c = opt$do -> {
        let c1 = Optional.of("c1");
        let c2 = Optional.of("c2");
        return c1 + "-" + c2;
    };
    return a + b + "/" + c;
};

ソースは http://github.com/fits/try_samples/tree/master/blog/20150516/

はじめに

環境

下記のような環境を使ってビルド・実行しています。

  • JavaSE Development Kit 8u45 (1.8.0_45)
  • Gradle 2.4

構文

前回 からの変更点は以下の通りです。

  • 対象のラムダ式JCLambda)を完全に置換し、Supplier を不要にした
  • let で入れ子に対応

対象のラムダ式を別の式 (JCMethodInvocation) で完全に置換し、Supplier を無くした事でまともな構文になったと思います。

変換前 (アノテーションプロセッサ処理前)
Optional<String> res = opt$do -> {
    let a = Optional.of("a");
    let b = Optional.of("b");
    let c = opt$do -> {
        let c1 = Optional.of("c1");
        let c2 = Optional.of("c2");
        return c1 + "-" + c2;
    };
    return a + b + "/" + c;
};
変換後 (アノテーションプロセッサ処理後)
Optional<String> res = opt.bind(
    Optional.of("a"), 
    (a) -> opt.bind(
        Optional.of("b"), 
        (b) -> opt.bind(
            opt.bind(
                Optional.of("c1"), 
                (c1) -> opt.bind(
                    Optional.of("c2"), 
                    (c2) -> opt.unit(c1 + "-" + c2)
                )
            ), 
            (c) -> opt2.unit(a + b + "/" + c)
        )
    )
);

また、変数への代入だけではなく、メソッドの引数にも上記構文を使えるようにしました。

メソッド引数としての使用例
System.out.println(opt$do -> {
    let a = Optional.of("a");
    let b = Optional.of("b");
    return "***" + b + a;
});

アノテーションプロセッサの実装

Processor の実装

前回 とほぼ同じですが、DoExprVisitor の extends 元を com.sun.tools.javac.tree.TreeScanner へ変えたので、accept メソッドの呼び出し箇所が多少変わっています。

なお、JCTree へキャストしていますが、JCCompilationUnit へキャストしても問題ありません。

src/main/java/sample/DoExprProcessor.java
package sample;

import java.util.Set;
import javax.annotation.processing.*;

import javax.lang.model.SourceVersion;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;

import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.util.Trees;
import com.sun.source.util.TreePath;

import com.sun.tools.javac.processing.JavacProcessingEnvironment;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.util.Context;

@SupportedSourceVersion(SourceVersion.RELEASE_8)
@SupportedAnnotationTypes("*")
public class DoExprProcessor extends AbstractProcessor {
    private Trees trees;
    private Context context;

    @Override
    public void init(ProcessingEnvironment procEnv) {
        trees = Trees.instance(procEnv);
        context = ((JavacProcessingEnvironment)procEnv).getContext();
    }

    @Override
    public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
        roundEnv.getRootElements().stream().map(this::toUnit).forEach(this::processUnit);
        return false;
    }

    private CompilationUnitTree toUnit(Element el) {
        TreePath path = trees.getPath(el);
        return path.getCompilationUnit();
    }

    private void processUnit(CompilationUnitTree cu) {
        if (cu instanceof JCTree) {
            ((JCTree)cu).accept(new DoExprVisitor(context));
            // 変換内容を出力
            System.out.println(cu);
        }
    }
}

TreeVisitor の実装

前回 からの変更点は以下の通りです。

  • (a) コード生成部分を別クラス化
  • (b) 対象のラムダ式JCLambda) を全置換
  • (c) メソッド引数への対応
  • (d) extends 元を com.sun.tools.javac.tree.TreeScanner へ変更 (前回までは com.sun.source.util.TreeScanner

(b) を実現するため changeNode へ置換処理 (JCLambdaJCMethodInvocation へ差し替える事になる) を設定するようにしました。

主な処理内容は次のようになっています。

  • (1) 変数定義(JCVariableDecl)やメソッド実行(JCMethodInvocation)の箇所で該当部分を差し替えるための処理を changeNode へ設定
  • (2) ラムダの内容からソースコードを生成 (対象外なら何もしない)
  • (3) ソースコードJCExpression へパースして (実体は JCMethodInvocationpos の値を調整
  • (4) changeNode を実行しラムダ箇所を差し替え
src/main/java/sample/DoExprVisitor.java
package sample;

import com.sun.tools.javac.parser.ParserFactory;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.*;
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context;

import java.util.function.BiConsumer;
import java.util.stream.Stream;

public class DoExprVisitor extends TreeScanner {
    private ParserFactory parserFactory;
    private BiConsumer<JCLambda, JCExpression> changeNode = (lm, ne) -> {};
    private DoExprBuilder builder = new DoExprBuilder();

    public DoExprVisitor(Context context) {
        parserFactory = ParserFactory.instance(context);
    }

    @Override
    public void visitVarDef(JCVariableDecl node) {
        if (node.init != null) {
            // (b) (1)
            changeNode = (lm, ne) -> {
                // 変数への代入式を置換
                if (node.init == lm) {
                    node.init = ne;
                }
            };
        }
        super.visitVarDef(node);
    }

    // (c)
    @Override
    public void visitApply(JCMethodInvocation node) {
        if (node.args != null && node.args.size() > 0) {
            // (b) (1)
            changeNode = (lm, ne) -> {
                // メソッドの引数部分を置換
                if (node.args.contains(lm)) {
                    Stream<JCExpression> newArgs = node.args.stream().map(a -> (a == lm)? ne: a);
                    node.args = com.sun.tools.javac.util.List.from(newArgs::iterator);
                }
            };
        }
        super.visitApply(node);
    }

    @Override
    public void visitLambda(JCLambda node) {
        // (a) (2)
        builder.build(node).ifPresent(expr -> {
            // (3)
            JCExpression ne = parseExpression(expr);
            fixPos(ne, node.pos);

            // (b) (4) ラムダ部分を差し替え
            changeNode.accept(node, ne);
        });

        super.visitLambda(node);
    }

    // pos 値の修正
    private void fixPos(JCExpression ne, final int basePos) {
        ne.accept(new TreeScanner() {
            @Override
            public void scan(JCTree tree) {
                if(tree != null) {
                    tree.pos += basePos;
                    super.scan(tree);
                }
            }
        });
    }

    // 生成したソースコードをパース
    private JCExpression parseExpression(String doExpr) {
        return parserFactory.newParser(doExpr, false, false, false).parseExpression();
    }
}

コード生成処理の実装

該当のラムダ式を変換したソースコードを生成する処理です。

src/main/java/sample/DoExprBuilder.java
package sample;

import com.sun.tools.javac.tree.JCTree.*;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class DoExprBuilder {
    private static final String DO_TYPE = "$do";
    private static final String VAR_PREFIX = "#{";
    private static final String VAR_SUFFIX = "}";
    // let 用のコードテンプレート
    private static final String LET_CODE = "#{var}.bind(#{rExpr}, #{lExpr} -> #{body})";
    // return 用のコードテンプレート
    private static final String RETURN_CODE = "#{var}.unit( #{expr} )";

    private Map<Class<? extends JCStatement>, CodeGenerator<JCStatement>> builderMap = new HashMap<>();

    public DoExprBuilder() {
        // let 用のコード生成
        builderMap.put(JCVariableDecl.class, (n, v, b) -> generateCodeForLet(cast(n), v, b));
        // return 用のコード生成
        builderMap.put(JCReturn.class, (n, v, b) -> generateCodeForReturn(cast(n), v, b));
    }

    public Optional<String> build(JCLambda node) {
        return getDoVar(node).map(var -> createExpression((JCBlock)node.body, var));
    }

    private String createExpression(JCBlock block, String var) {
        String res = "";

        for (JCStatement st : block.stats.reverse()) {
            res = builderMap.getOrDefault(st.getClass(), this::generateNoneCode).generate(st, var, res);
        }
        return res;
    }

    private String generateNoneCode(JCStatement node, String var, String body) {
        return body;
    }

    // let 用のソースコード生成
    private String generateCodeForLet(JCVariableDecl node, String var, String body) {
        String res = body;

        if ("let".equals(node.vartype.toString())) {
            Map<String, String> params = createParams(var);
            params.put("body", res);
            params.put("lExpr", node.name.toString());
            params.put("rExpr", node.init.toString());

            // 入れ子への対応
            if (node.init instanceof JCLambda) {
                JCLambda lm = cast(node.init);

                getDoVar(lm).ifPresent(childVar ->
                        params.put("rExpr", createExpression((JCBlock) lm.body, childVar)));
            }
            res = buildTemplate(LET_CODE, params);
        }

        return res;
    }

    // return 用のソースコード生成
    private String generateCodeForReturn(JCReturn node, String var, String body) {
        Map<String, String> params = createParams(var);
        params.put("expr", node.expr.toString());

        return buildTemplate(RETURN_CODE, params);
    }

    // 処理変数名の抽出
    private Optional<String> getDoVar(JCLambda node) {
        if (node.params.size() == 1) {
            String name = node.params.get(0).name.toString();

            if (name.endsWith(DO_TYPE)) {
                return Optional.of(name.replace(DO_TYPE, ""));
            }
        }
        return Optional.empty();
    }

    private Map<String, String> createParams(String var) {
        Map<String, String> params = new HashMap<>();

        params.put("var", var);

        return params;
    }

    // テンプレート処理
    private String buildTemplate(String template, Map<String, String> params) {
        String res = template;

        for(Map.Entry<String, String> param : params.entrySet()) {
            res = res.replace(VAR_PREFIX + param.getKey() + VAR_SUFFIX, param.getValue());
        }
        return res;
    }

    @SuppressWarnings("unchecked")
    private <S, T> T cast(S obj) {
        return (T)obj;
    }

    private interface CodeGenerator<T> {
        String generate(T node, String var, String body);
    }
}

Service Provider 設定ファイルやビルド定義は 前回 と同じものです。

Service Provider 設定ファイル

src/main/resources/META-INF/services/javax.annotation.processing.Processor
sample.DoExprProcessor
build.gradle
apply plugin: 'java'

def enc = 'UTF-8'
tasks.withType(AbstractCompile)*.options*.encoding = enc

dependencies {
    compile files("${System.properties['java.home']}/../lib/tools.jar")
}

ビルド

ビルド実行
> gradle build

:compileJava
:processResources UP-TO-DATE
:classes
:jar
:assemble
:compileTestJava UP-TO-DATE
:processTestResources UP-TO-DATE
:testClasses UP-TO-DATE
:test UP-TO-DATE
:check UP-TO-DATE
:build

BUILD SUCCESSFUL

ビルド結果として build/libs/java_do_expr.jar3 が生成されます。

動作確認

下記のサンプルコードを使ってアノテーションプロセッサの動作確認を行います。

example/DoExprSample.java
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.Optional;

public class DoExprSample {
    public static void main(String... args) {
        Optional<Integer> o1 = Optional.of(2);
        Optional<Integer> o2 = Optional.of(3);

        Opt<Integer> opt = new Opt<>();

        Optional<Integer> res = opt$do -> {
            let a = o1;
            let b = o2;
            let c = Optional.of(4);
            return a + b + c * 2;
        };
        // Optional[13]
        System.out.println(res);

        Opt<String> opt2 = new Opt<>();

        Optional<String> res2 = opt2$do -> {
            let a = Optional.of("a");
            let b = Optional.of("b");
            let c = opt2$do -> {
                let c1 = Optional.of("c1");
                let c2 = Optional.of("c2");
                return c1 + "-" + c2;
            };
            return a + b + "/" + c;
        };
        // Optional[ab/c1-c2]
        System.out.println(res2);

        // Optional[***ba]
        System.out.println(opt2$do -> {
            let a = Optional.of("a");
            let b = Optional.of("b");
            return "***" + b + a;
        });
    }

    static class Opt<T> {
        public Optional<T> bind(Optional<T> x, Function<T, Optional<T>> f) {
            return x.flatMap(f);
        }

        public Optional<T> unit(T v) {
            return Optional.ofNullable(v);
        }
    }
}

java_do_expr3.jar を使って上記ソースファイルをコンパイルします。

出力内容(変換後のソースコード)を見る限りは変換できているようです。

コンパイル
> javac -cp ../build/libs/java_do_expr3.jar DoExprSample.java

・・・
public class DoExprSample {
    ・・・
    public static void main(String... args) {
        Optional<Integer> o1 = Optional.of(2);
        Optional<Integer> o2 = Optional.of(3);
        Opt<Integer> opt = new Opt<>();
        Optional<Integer> res = opt.bind(o1, (a)->opt.bind(o2, (b)->opt.bind(Optional.of(4), (c)->opt.unit(a + b + c * 2))));
        System.out.println(res);
        Opt<String> opt2 = new Opt<>();
        Optional<String> res2 = opt2.bind(Optional.of("a"), (a)->opt2.bind(Optional.of("b"), (b)->opt2.bind(opt2.bind(Optional.of("c1"), (c1)->opt2.bind(Optional.of("c2"), (c2)->opt2.unit(c1 + "-" + c2))), (c)->opt2.unit(a + b + "/" + c))));
        System.out.println(res2);
        System.out.println(opt2.bind(Optional.of("a"), (a)->opt2.bind(Optional.of("b"), (b)->opt2.unit("***" + b + a))));
    }
    ・・・
}

DoExprSample を実行すると正常に動作しました。

実行結果
> java DoExprSample

Optional[13]
Optional[ab/c1-c2]
Optional[***ba]