Java 8 ラムダ式の実装メソッド名を取得 - SerializedLambda

Java 8 ラムダ式の実装メソッド名を実行時に取得する方法を探ってみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20140413/

はじめに

前回ラムダ式デコンパイルしてみましたが、ラムダ式の処理は lambda$main$0 のような synthetic メソッドに実装されます。

そして、ラムダ式の部分 (下記 f 変数の値) は実行時に java.lang.invoke.LambdaMetafactory.metafactory() で動的にクラス・オブジェクトが生成されます。

Sample0.java
import java.util.function.Predicate;

class Sample0 {
    public static void main(String... args) {
        int a = 3;
        Predicate<Integer> f = (x) -> x % a == 0;
        System.out.println(f.test(6));
    }
}
Sample0 の javap 結果 (lambda$main$0 がラムダ式の実装)
> javap -p Sample0
Compiled from "Sample0.java"
class Sample0 {
  Sample0();
  public static void main(java.lang.String...);
  private static boolean lambda$main$0(int, java.lang.Integer);
}

そこで、実行時に生成されたラムダ式のオブジェクトから実装メソッド名などの情報を取得する事ができるのか調査してみました。

結論としては、通常のラムダ式では無理そうでしたが、シリアライズ可にする事で取得できると分かりました。

  • ラムダ式を Serializable にすると writeReplace メソッドからラムダ式の情報が入った java.lang.invoke.SerializedLambda を取得可能

ただし、writeReplace は private final なメソッドです。(シリアライズのための処理なので)

ちなみに、ラムダ式部分の動的なクラス生成などは java.lang.invoke.InnerClassLambdaMetafactory 内で ASM を使って実施されているようです。

シリアライズ可能なラムダ式から SerializedLambda を取得

それでは、シリアライズ可能なラムダ用のインターフェース (下記の SPredicate) を定義し、writeReplace メソッドを実行して SerializedLambda を取得してみます。

Sample1.java
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.function.Predicate;
import java.io.Serializable;

class Sample1 {
    public static void main(String... args) throws Exception {
        int a = 3;

        SPredicate<Integer> f = (x) -> x % a == 0;

        Method m = f.getClass().getDeclaredMethod("writeReplace");
        m.setAccessible(true);

        // リフレクションで writeReplace を実行し SerializedLambda を取得
        SerializedLambda sl = (SerializedLambda)m.invoke(f);

        System.out.println(sl);

        System.out.println("-----");

        // 実装先クラス名の出力
        System.out.println(sl.getImplClass());
        // 実装先メソッド名の出力
        System.out.println(sl.getImplMethodName());

        // 設定されている引数の出力
        for (int i = 0; i < sl.getCapturedArgCount(); i++) {
            System.out.println("arg: " + sl.getCapturedArg(i));
        }
    }

    // シリアライズ可能にした Predicate の定義
    public interface SPredicate<T> extends Predicate<T>, Serializable {
    }
}

実行してみると、実装メソッド名や設定されている引数(上記 a 変数に該当する 3 の値)を取得できました。

実行結果
> java Sample1
SerializedLambda[capturingClass=class Sample1, 
functionalInterfaceMethod=Sample1$SPredicate.test:(Ljava/lang/Object;)Z, 
implementation=invokeStatic Sample1.lambda$main$50fc8a8$1:(ILjava/lang/Integer;)Z, 
instantiatedMethodType=(Ljava/lang/Integer;)Z, numCaptured=1]
-----
Sample1
lambda$main$50fc8a8$1
arg: 3

CFR でラムダ式の実装メソッドをデコンパイル

SerializedLambda と前回の CFR を使えば、実行時にラムダ式の実装メソッドをデコンパイルしたりする事も可能です。

Sample2.java
import java.lang.invoke.SerializedLambda;
import java.util.function.Predicate;
import java.io.Serializable;

import org.benf.cfr.reader.util.getopt.GetOptParser;
import org.benf.cfr.reader.util.getopt.Options;
import org.benf.cfr.reader.util.getopt.OptionsImpl;
import org.benf.cfr.reader.entities.ClassFile;
import org.benf.cfr.reader.entities.Method;
import org.benf.cfr.reader.state.DCCommonState;
import org.benf.cfr.reader.util.output.ToStringDumper;

class Sample2 {
    public static void main(String... args) throws Exception {
        int a = 3;

        SPredicate<Integer> f = (x) -> x % a == 0;

        java.lang.reflect.Method m = f.getClass().getDeclaredMethod("writeReplace");
        m.setAccessible(true);

        SerializedLambda sl = (SerializedLambda)m.invoke(f);

        String src = decompileLambda(sl);

        System.out.println(src);
    }

    // ラムダ式の実装メソッドをデコンパイル
    private static String decompileLambda(SerializedLambda sl) throws Exception {
        ToStringDumper d = new ToStringDumper();

        Options options = new GetOptParser().parse(new String[] {sl.getImplClass()}, OptionsImpl.getFactory());
        DCCommonState dcCommonState = new DCCommonState(options);

        ClassFile c = dcCommonState.getClassFileMaybePath(options.getFileName());
        c = dcCommonState.getClassFile(c.getClassType());

        for (Method m : c.getMethodByName(sl.getImplMethodName())) {
            m.dump(d, true);
        }

        return d.toString();
    }

    public interface SPredicate<T> extends Predicate<T>, Serializable {
    }
}

実行結果は下記の通り、ラムダ式の実装メソッド(lambda$main$50fc8a8$1)のソースが出力されます。

実行結果
> java -cp .;cfr_0_78.jar Sample2
private static /* synthetic */ boolean lambda$main$50fc8a8$1(int n, java.lang.Integer n2) {
    return n2 % n == 0;
}

上記を工夫すれば groovy.sql.DataSet のような O-R マッピング処理(クロージャSQL の where 部分を定義)をラムダ式で実現できると思います。

CFR で Java 8 のラムダ式をデコンパイルする

Java 8 のラムダ式にも対応した CFR という Java のデコンパイラをご紹介します。

使い方

使い方は簡単で、http://www.benf.org/other/cfr/ から JAR ファイルをダウンロードして下記のように実行するだけです。

java -jar cfr_0_78.jar <Java クラスファイル> [メソッド名] [オプション]
java -jar cfr_0_78.jar <JAR ファイル> [オプション]

ラムダ式デコンパイル

それでは、下記ソースをコンパイルして出来た LambdaSample.class を CFR でデコンパイルしてみます。

LambdaSample.java
import java.util.function.Predicate;

class LambdaSample {
    public static void main(String... args) {
        Predicate<Integer> f = (x) -> x > 10;

        System.out.println("5 : " + f.test(5));
        System.out.println("15 : " + f.test(15));
    }
}

まずは、オプションを全く指定せずにデコンパイルしてみます。 ラムダ式の部分も見事にデコンパイルされたソースコードが標準出力に出力されます。

デコンパイル結果(オプション指定なし)
> java -jar cfr_0_78.jar LambdaSample.class

/*
 * Decompiled with CFR 0_78.
 */
import java.io.PrintStream;
import java.util.function.Predicate;

class LambdaSample {
    LambdaSample() {
    }

    public static /* varargs */ void main(String ... arrstring) {
        Predicate<Integer> predicate = n -> n > 10;
        System.out.println("5 : " + predicate.test(5));
        System.out.println("15 : " + predicate.test(15));
    }
}

次に、--decodelambdas false オプションを指定してラムダ式の部分をデコンパイルしないようにしてみます。

こうする事で、ラムダ式へ展開されず synthetic のメソッド定義 (ラムダ式の実体)と LambdaMetafactory.metafactory() 呼び出し処理のソースが出力されます。

デコンパイル結果(--decodelambdas false)
> java -jar cfr_0_78.jar LambdaSample.class --decodelambdas false

/*
 * Decompiled with CFR 0_78.
 */
import java.io.PrintStream;
import java.util.function.Predicate;

class LambdaSample {
    LambdaSample() {
    }

    public static /* varargs */ void main(String ... arrstring) {
        Predicate<Integer> predicate = (Predicate<Integer>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Z, lambda$main$0(java.lang.Integer ), (Ljava/lang/Integer;)Z)();
        System.out.println("5 : " + predicate.test(5));
        System.out.println("15 : " + predicate.test(15));
    }

    private static /* synthetic */ boolean lambda$main$0(Integer n) {
        return n > 10;
    }
}

Javaの列挙型(Enum)へ新しい要素を追加2 - Javassist

前回Java の列挙型(Enum)へ新しい要素(識別子)を追加するためリフレクションを駆使しましたが、今回は Javassist を使ってもっと容易に実現する方法をご紹介します。

ソースは http://github.com/fits/try_samples/tree/master/blog/20140316/

はじめに

前回と同様の列挙型へ Second 要素を追加してみる事にします。

EType.java
enum EType {
    First
}

Javassist を使った列挙型への要素追加

Javassist を使う場合、下記を実施して $VALUES フィールドへ列挙型の要素を好きなように設定するだけです。

  • CtConstructorinsertAfter メソッドを使って、静的初期化子へ $VALUES を書き換える処理を挿入

なお、静的初期化子のための CtConstructor オブジェクトは CtClass オブジェクトの getClassInitializerメソッドで取得します。

EnumAddValueJavassist.java
import javassist.*;

public class EnumAddValueJavassist {
    public static void main(String... args) throws Exception {
        ClassPool pool = ClassPool.getDefault();

        CtClass cc = pool.get("EType");

        // 静的初期化子(static イニシャライザ)へ $VALUES を変更する処理を追加
        cc.getClassInitializer().insertAfter("$VALUES = new EType[] { First, new EType(\"Second\", 1) };");

        cc.toClass();

        System.out.println(EType.valueOf("Second"));

        System.out.println("-----");

        for (EType type : EType.values()) {
            System.out.println(type);
        }
    }
}

実行すると下記のように Second の追加を確認できます。

ビルドと実行
> javac -cp .;javassist-3.18.1-GA.jar *.java

> java -cp .;javassist-3.18.1-GA.jar EnumAddValueJavassist
Second
-----
First
Second

実際は、下記のように Second クラスフィールドを追加しておいた方が望ましいかもしれません。

EnumAddValueJavassist2.java
import javassist.*;

public class EnumAddValueJavassist2 {
    public static void main(String... args) throws Exception {
        ClassPool pool = ClassPool.getDefault();

        CtClass cc = pool.get("EType");

        // Second フィールドの追加
        CtField second = CtField.make("public static final EType Second = new EType(\"Second\", 1);", cc);
        cc.addField(second);

        cc.getClassInitializer().insertAfter("$VALUES = new EType[] { First, Second };");

        cc.toClass();
        ・・・
    }
}

Javaの列挙型(Enum)へ新しい要素を追加

Java の列挙型(Enum)へ新しい要素(識別子)を動的に追加する方法を探ってみました。

列挙型の場合、普通のリフレクションクラスではインスタンス化できませんので、下記のように sun パッケージのクラスを使用する必要があります。

  • (1) sun.reflect.ConstructorAccessor で列挙型の新しい要素をインスタンス
  • (2) sun.misc.Unsafe で列挙型の $VALUES フィールドへ (1) のインスタンスを追加

そのため、下記の環境では動作確認できましたが、他の Java 実行環境では使えないかもしれません。

  • Java SE 7 (1.7.0_51-b13)
  • Java SE 8 (1.8.0-b129)

ソースは http://github.com/fits/try_samples/tree/master/blog/20140309/

はじめに

今回は下記のような列挙型へ Second と Third の要素を追加する事にします。

enum EType {
    First
}

Constructor の newInstance では列挙型をインスタンス化できない

下記のように ConstructornewInstance メソッドで列挙型(今回の EType)の新しい要素をインスタンス化したいところなのですが、IllegalArgumentException エラーとなります。

EnumAddValueError.java
import java.lang.reflect.Constructor;

public class EnumAddValueError {
    public static void main(String... args) throws Exception {
        // EType のコンストラクタ(private)取得
        Constructor<EType> cls = EType.class.getDeclaredConstructor(String.class, int.class);
        // private コンストラクタを実行するための設定
        cls.setAccessible(true);

        // 列挙型は newInstance できないのでエラー
        EType t2 = cls.newInstance("Second", 1);
    }

    enum EType {
        First
    }
}
実行例
> java EnumAddValueError
Exception in thread "main" java.lang.IllegalArgumentException: Cannot reflectively create enum objects
        at java.lang.reflect.Constructor.newInstance(Constructor.java:521)
        at EnumAddValueError.main(EnumAddValueError.java:11)

これは newInstance メソッドの実装内容が原因です。

java.lang.reflect.Constructor のソース(一部抜粋)
public final class Constructor<T> extends Executable {
    ・・・
    @CallerSensitive
    public T newInstance(Object ... initargs)
        throws InstantiationException, IllegalAccessException,
               IllegalArgumentException, InvocationTargetException
    {
        ・・・
        if ((clazz.getModifiers() & Modifier.ENUM) != 0)
            throw new IllegalArgumentException("Cannot reflectively create enum objects");
        ConstructorAccessor ca = constructorAccessor;   // read volatile
        if (ca == null) {
            ca = acquireConstructorAccessor();
        }
        @SuppressWarnings("unchecked")
        T inst = (T) ca.newInstance(initargs);
        return inst;
    }
    ・・・
}

ここで、sun.reflect.ConstructorAccessor を直接使えば IllegalArgumentException を避けてインスタンス化できる事も分かります。

実装

それでは、列挙型へ新しい要素を追加する処理を実装してみます。

(1) sun.reflect.ConstructorAccessor で列挙型の新しい要素をインスタンス

まずは、列挙型(下記の EType)の新しいインスタンスを作成する必要がありますが、前述したように sun.reflect.ConstructorAccessor を使う必要があります。

ここで、どのようにして列挙型の ConstructorAccessor を入手するかが課題となりますが、今回は Constructor の acquireConstructorAccessor メソッドを使って取得してみました。

acquireConstructorAccessor は private メソッドなのでリフレクションを使って実行します。

列挙型の ConstructorAccessor が手に入れば、newInstance メソッドを実行して列挙型の新しいインスタンスを得る事ができます。

EnumAddValue.java (列挙型のインスタンス化)
・・・
import sun.reflect.ConstructorAccessor;

public class EnumAddValue {
    public static void main(String... args) throws Exception {
        EType t2 = addEnumValue(EType.class, "Second", 1);
        ・・・
    }

    // (1) sun.reflect.ConstructorAccessor で列挙型の新しい要素をインスタンス化
    private static <T extends Enum<?>> T addEnumValue(Class<T> enumClass, String name, int ordinal) throws Exception {
        // acquireConstructorAccessor メソッド
        Method m = Constructor.class.getDeclaredMethod("acquireConstructorAccessor");
        m.setAccessible(true);

        // 列挙型のコンストラクタ取得
        Constructor<T> cls = enumClass.getDeclaredConstructor(String.class, int.class);
        // acquireConstructorAccessor を実行し ConstructorAccessor を取得
        ConstructorAccessor ca = (ConstructorAccessor)m.invoke(cls);

        // 列挙型の新しい要素をインスタンス化
        @SuppressWarnings("unchecked")
        T result = (T)ca.newInstance(new Object[]{name, ordinal});

        // (2) sun.misc.Unsafe で列挙型の $VALUES フィールドへ (1) のインスタンスを追加
        addValueToEnum(result);

        return result;
    }
    ・・・
    enum EType {
        First
    }
}

(2) sun.misc.Unsafe で列挙型の $VALUES フィールドへ (1) のインスタンスを追加

(1) で列挙型の新しい要素をインスタンス化できるようになりましたが、これだけでは不十分です。

valueOf メソッド等を使えるようにするには、列挙型の $VALUES クラスフィールド(private static final)へ新しいインスタンスを追加する必要があります。

private final なクラスフィールドの内容を強引に変更するには sun.misc.Unsafe クラスを使用する事になります。

ここで、Unsafe のインスタンスUnsafe.getUnsafe() で取得したいところですが、今回のやり方だと SecurityException エラーとなってしまいます。

SecurityException を回避するのは面倒そうだったので、今回はリフレクションを使って Unsafe の theUnsafe クラスフィールド(private static final)から Unsafe インスタンスを取得してみました。

Unsafe のインスタンスを得られれば putObjectVolatile メソッド等で private final なクラスフィールドの内容を変更できます。

EnumAddValue.java ($VALUES への要素追加)
・・・
    // (2) sun.misc.Unsafe で列挙型の $VALUES フィールドへ (1) のインスタンスを追加
    private static <T extends Enum<?>> void addValueToEnum(T newValue) throws Exception {
        // $VALUES フィールド取得
        Field f = newValue.getClass().getDeclaredField("$VALUES");
        f.setAccessible(true);
        // $VALUES の値を取得
        @SuppressWarnings("unchecked")
        T[] values = (T[])f.get(null);

        T[] newValues = Arrays.copyOf(values, values.length + 1);
        // 列挙型の新しい要素を追加
        newValues[values.length] = newValue;

        // theUnsafe フィールド
        Field uf = Unsafe.class.getDeclaredField("theUnsafe");
        uf.setAccessible(true);

        // theUnsafe フィールドから Unsafe インスタンスを取得
        Unsafe unsafe = (Unsafe)uf.get(null);

        // $VALUES フィールドへ値を設定
        unsafe.putObjectVolatile(unsafe.staticFieldBase(f), unsafe.staticFieldOffset(f), newValues);
    }
・・・

実行

今回作成したソースの全容は下記の通りです。

EnumAddValue.java
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;

import sun.misc.Unsafe;
import sun.reflect.ConstructorAccessor;

public class EnumAddValue {
    public static void main(String... args) throws Exception {
        EType t2 = addEnumValue(EType.class, "Second", 1);
        System.out.println(t2);

        EType t3 = addEnumValue(EType.class, "Third", 2);
        System.out.println(EType.valueOf("Third"));
        System.out.println("Thrid == t3 : " + (EType.valueOf("Third") == t3));

        System.out.println("-----");

        for (EType type : EType.values()) {
            System.out.println(type);
        }
    }

    // (1) sun.reflect.ConstructorAccessor で列挙型の新しい要素をインスタンス化
    private static <T extends Enum<?>> T addEnumValue(Class<T> enumClass, String name, int ordinal) throws Exception {
        Method m = Constructor.class.getDeclaredMethod("acquireConstructorAccessor");
        m.setAccessible(true);

        Constructor<T> cls = enumClass.getDeclaredConstructor(String.class, int.class);
        ConstructorAccessor ca = (ConstructorAccessor)m.invoke(cls);

        @SuppressWarnings("unchecked")
        T result = (T)ca.newInstance(new Object[]{name, ordinal});

        addValueToEnum(result);

        return result;
    }

    // (2) sun.misc.Unsafe で列挙型の $VALUES フィールドへ (1) のインスタンスを追加
    private static <T extends Enum<?>> void addValueToEnum(T newValue) throws Exception {
        Field f = newValue.getClass().getDeclaredField("$VALUES");
        f.setAccessible(true);

        @SuppressWarnings("unchecked")
        T[] values = (T[])f.get(null);

        T[] newValues = Arrays.copyOf(values, values.length + 1);
        newValues[values.length] = newValue;

        Field uf = Unsafe.class.getDeclaredField("theUnsafe");
        uf.setAccessible(true);

        Unsafe unsafe = (Unsafe)uf.get(null);

        unsafe.putObjectVolatile(unsafe.staticFieldBase(f), unsafe.staticFieldOffset(f), newValues);
    }

    enum EType {
        First
    }
}

ビルドすると ConstructorAccessor・Unsafe を使っている事に対する警告が出ます。

ビルド
> javac EnumAddValue.java

EnumAddValue.java:7: 警告: Unsafeは内部所有のAPIであり、今後のリリースで削除される可能性があります
import sun.misc.Unsafe;
・・・
警告7個

実行してみると、一応 Second と Third を追加できている事を確認できました。

実行
> java EnumAddValue
Second
Third
Thrid == t3 : true
-----
First
Second
Third

R でロジスティック回帰とオッズ比の算出 - glm, MCMClogit

以前、glm・MCMCmetrop1R 関数でロジスティック回帰を試みましたが、今回はその時に利用を断念した MCMCpack の MCMClogit 関数を使ってロジスティック回帰を行います。

題材は、書籍 「 データサイエンティスト養成読本 [ビッグデータ時代のビジネスを支えるデータ分析力が身につく! ] (Software Design plus) 」の p.38 と同様の Titanic データセットを使ったロジスティック回帰とオッズ比の算出です。

以前使ったデータは主に数値でしたが、今回は主に因子(factor)データを扱っている点が異なっています。

今回のソースは http://github.com/fits/try_samples/tree/master/blog/20140302/

はじめに

MCMCpack パッケージを R へインストールしておきます。

install.packages("MCMCpack")

Titanic データセット

Titanic のデータセットは R にデフォルトで用意されており、data.frame(Titanic) すると下記のようになります。

data.frame(Titanic) 結果
> data.frame(Titanic)

   Class    Sex   Age Survived Freq
1    1st   Male Child       No    0
2    2nd   Male Child       No    0
3    3rd   Male Child       No   35
4   Crew   Male Child       No    0
5    1st Female Child       No    0
6    2nd Female Child       No    0
7    3rd Female Child       No   17
8   Crew Female Child       No    0
9    1st   Male Adult       No  118
10   2nd   Male Adult       No  154
11   3rd   Male Adult       No  387
12  Crew   Male Adult       No  670
13   1st Female Adult       No    4
14   2nd Female Adult       No   13
15   3rd Female Adult       No   89
16  Crew Female Adult       No    3
17   1st   Male Child      Yes    5
18   2nd   Male Child      Yes   11
19   3rd   Male Child      Yes   13
20  Crew   Male Child      Yes    0
・・・

内容は以下の通りです。

Class(船室等級) Sex(性別) Age(年齢層) Survived(生存可否) Freq(人数)
1st, 2nd, 3rd, Crew Male, Female Child, Adult No, Yes 数値

今回は、ロジスティック回帰の結果から船室等級・性別・年齢層毎の生存率に対するオッズ比を算出します。

(1) glm を使ったロジスティック回帰とオッズ比

書籍の内容とほとんど同じですが、 まずは glm 関数を使ったロジスティック回帰です。

Survived~. としているので Class・Sex・Age を説明変数としたロジスティック回帰を実施します。

オッズ比は exp(<推定値>) で算出できるので、書籍のような epicalc パッケージは使わず、glm 結果の $coefficientsexp 関数へ渡してオッズ比を算出しました。

logiMcmcglmm.R
d <- data.frame(Titanic)

d.data <- data.frame(
  Class = rep(d$Class, d$Freq),
  Sex = rep(d$Sex, d$Freq),
  Age = rep(d$Age, d$Freq),
  Survived = rep(d$Survived, d$Freq)
)

d.res <- glm(Survived~., data = d.data, family = binomial)
summary(d.res)

# オッズ比
exp(d.res$coefficients)

なお、上記では data.frame(Titanic) の Freq(人数) を展開したデータフレーム(下記)を使ってロジスティック回帰を実施しています。

d.data の内容
   Class    Sex   Age Survived
1    3rd   Male Child       No
2    3rd   Male Child       No
・・・
35   3rd   Male Child       No
36   3rd Female Child       No
・・・

実行結果は下記の通りです。

実行結果
・・・
> d.res <- glm(Survived~., data = d.data, family = binomial)
> summary(d.res)

Call:
glm(formula = Survived ~ ., family = binomial, data = d.data)

Deviance Residuals: 
    Min       1Q   Median       3Q      Max  
-2.0812  -0.7149  -0.6656   0.6858   2.1278  

Coefficients:
            Estimate Std. Error z value Pr(>|z|)    
(Intercept)   0.6853     0.2730   2.510   0.0121 *  
Class2nd     -1.0181     0.1960  -5.194 2.05e-07 ***
Class3rd     -1.7778     0.1716 -10.362  < 2e-16 ***
ClassCrew    -0.8577     0.1573  -5.451 5.00e-08 ***
SexFemale     2.4201     0.1404  17.236  < 2e-16 ***
AgeAdult     -1.0615     0.2440  -4.350 1.36e-05 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 2769.5  on 2200  degrees of freedom
Residual deviance: 2210.1  on 2195  degrees of freedom
AIC: 2222.1

Number of Fisher Scoring iterations: 4

> 
> # オッズ比
> exp(d.res$coefficients)
(Intercept)    Class2nd    Class3rd   ClassCrew   SexFemale    AgeAdult 
  1.9844057   0.3612825   0.1690159   0.4241466  11.2465380   0.3459219 

オッズ比から、女性(Female)は男性(Male)の 11倍(11.2465380)、大人(Adult)は子供(Child)の 3分の1 (0.3459219)の生存率という結果になりました。

(2) MCMClogit を使ったロジスティック回帰とオッズ比

それでは、MCMCpack の MCMClogit 関数を使って同様の処理を実施してみます。

glm 関数の場合とほとんど同じですが、MCMClogit 関数は応答変数(下記の Survived)に因子型(factor)を扱えないようなので、Survived の No と Yes を 0(= No) と 1(= Yes)へ変換しています。 (as.numeric(d$Survived) - 1 の箇所)

また、MCMClogit 関数の結果も glm とは違って値の分布となるので summary した結果の平均値(期待値) Mean を使ってオッズ比を算出しています。

logistic_odds_mcmclogit.R
library(MCMCpack)

d <- data.frame(Titanic)

d.data <- data.frame(
  Class = rep(d$Class, d$Freq),
  Sex = rep(d$Sex, d$Freq),
  Age = rep(d$Age, d$Freq),
  Survived = rep(as.numeric(d$Survived) - 1, d$Freq)
)

d.res <- MCMClogit(Survived~., data = d.data)

d.summ <- summary(d.res)

d.summ

#オッズ比
exp(d.summ$statistics[, "Mean"])

実行結果は下記の通りです。

実行結果
・・・
> d.res <- MCMClogit(Survived~., data = d.data)
> 
> d.summ <- summary(d.res)
> 
> d.summ

Iterations = 1001:11000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 10000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

               Mean     SD Naive SE Time-series SE
(Intercept)  0.6775 0.2682 0.002682       0.012059
Class2nd    -1.0367 0.1961 0.001961       0.008828
Class3rd    -1.7935 0.1768 0.001768       0.008054
ClassCrew   -0.8607 0.1596 0.001596       0.006971
SexFemale    2.4486 0.1384 0.001384       0.006279
AgeAdult    -1.0540 0.2392 0.002392       0.010882

2. Quantiles for each variable:

               2.5%     25%     50%     75%   97.5%
(Intercept)  0.1455  0.4992  0.6831  0.8630  1.2028
Class2nd    -1.4185 -1.1683 -1.0331 -0.9027 -0.6626
Class3rd    -2.1481 -1.9130 -1.7926 -1.6760 -1.4476
ClassCrew   -1.1764 -0.9661 -0.8651 -0.7535 -0.5488
SexFemale    2.1543  2.3600  2.4525  2.5442  2.7110
AgeAdult    -1.5275 -1.2144 -1.0573 -0.8904 -0.5907

> 
> #オッズ比
> exp(d.summ$statistics[, "Mean"])
(Intercept)    Class2nd    Class3rd   ClassCrew   SexFemale    AgeAdult 
  1.9690188   0.3546137   0.1663746   0.4228566  11.5715841   0.3485381 

glm と同じような結果となりました。

最後に、それぞれの値の分布は下記のようになりました。

f:id:fits:20140302174258p:plain

f:id:fits:20140302174314p:plain

Windows上で Rust を使用

今回は Windows 上で Rust を使ってみます。

Rust はトレイト・パターンマッチ・アトリビュート等のモダンな言語機能を持ち、オブジェクト指向と純粋関数型のプログラミングスタイルをサポートしている、なかなか興味深いプログラミング言語です。

環境構築

まずは Rust のビルド・実行環境を構築します。
といっても下記を実施するだけです。

  • (1) Rust をインストール
  • (2) MinGWgcc をインストール

(2) に関しては、Haskell Platform 2013.2.0.0 for Windows 等をインストールしてあれば、その中に含まれている mingw を代わりに使えば良いので必要ありません。

また、Rust 0.9 では GCC に依存しているので (2) が必要となっていますが、将来的には変わるかもしれません。

(1) Rust をインストール

Using Rust on Windows の installer リンクから rust-0.9-install.exe をダウンロード、インストールします。

(2) MinGWgcc をインストール

http://sourceforge.net/projects/mingw/files/ から mingw-get-setup.exe をダウンロード、インストールします。

also install support for the graphical user interface のチェックを外してインストールした場合、「Quit」 ボタンを押下してインストールを終了する点にご注意ください。(この場合 「Continue」 ボタンは有効になりません)

次に MinGWgcc をインストールします。

コマンドラインでインストールする場合、mingw-get.exe (例 C:\MinGW\bin\mingw-get.exe) を使って gcc をインストールします。

MinGWgcc をインストール
> mingw-get install gcc

なお、Haskell Platform 2013.2.0.0 for Windows 等をインストールしてあれば、MinGW のインストールは不要です。

ビルドと実行

それでは下記のようなサンプルソースをビルドして実行してみます。

実行時の処理は main() 関数へ実装します。

sample.rs
fn main() {
    let d1 = Data { name: ~"data", value: 10 };
    let d2 = Data { name: ~"data", value: 10 };
    let d3 = Data { name: ~"data", value:  0 };
    let d4 = Data { name: ~"etc",  value:  5 };

    println!("d1 == d2 : {}", d1 == d2);
    println!("d1 == d2 : {}", d1.eq(&d2));
    println!("d1 == d3 : {}", d1 == d3);

    println!("-----")

    println!("{:?}", d1);
    println!("{}", d1.to_str());

    println!("-----")

    println!("times = {}", d1.times(3));

    println!("-----")

    d1.printValue();
    d3.printValue();

    println!("-----")

    let res = calc([d1, d2, d3, d4]);
    println!("calc = {}", res);
}

fn calc(list: &[Data]) -> int {
    list.iter().fold(1, |acc, v| acc * match v {
        // name = "data" で value の値が 0 より大きい場合
        &Data {name: ~"data", value: b} if b > 0 => b,
        // それ以外
        _ => 1
    })
}

// Eq と ToStr トレイトを自動導出した struct 型の定義
#[deriving(Eq, ToStr)]
struct Data {
    name: ~str,
    value: int
}

// Data のメソッド定義と実装
impl Data {
    fn printValue(&self) {
        match self.value {
            0 => println!("value: zero"),
            a @ _ => println!("value: {}", a)
        }
    }
}

// トレイト定義
trait Sample {
    fn get_value(&self) -> int;

    fn times(&self, n: int) -> int {
        self.get_value() * n
    }
}

// Data へ Sample トレイトを実装
impl Sample for Data {
    fn get_value(&self) -> int {
        self.value
    }
}

上記では deriving アトリビュートを使って Eq と ToStr トレイトを自動導出しています。 Eq の deriving で eq() メソッドが自動的に実装され == が使えるようになり、ToStr の deriving で to_str() メソッドが自動的に実装されます。

なお、struct 型(構造体)は {} で出力できないので、代わりに {:?} を使うか std::fmt::Default トレイトを実装する必要があります。

ちなみに、~ はポインタ(Owning pointers) で & はリファレンスです。

ビルド

まずは環境変数 PATH へ Rust と MinGW の bin ディレクトリのパスを設定しておきます。

環境変数 PATH の設定例1
set PATH=C:\Rust\bin;C:\MinGW\bin

MinGW の代わりに Haskell Platform 2013.2.0.0 for Windows を使う場合は下記のようになります。

環境変数 PATH の設定例2
set PATH=C:\Rust\bin;C:\Haskell Platform\2013.2.0.0\mingw\bin

rustc コマンドを使ってビルドを実施します。 ビルドに成功すると .exe ファイルが作成されます。

ビルド例
> rustc sample.rs

実行

ビルドで生成された .exe ファイルを実行します。

実行には Rust の bin ディレクトリ内の .dll ファイルを使いますので、Rust の bin は環境変数 PATH へ設定しておく必要があります。 (MinGW は基本的に不要です)

実行例
> sample.exe
d1 == d2 : true
d1 == d2 : true
d1 == d3 : false
-----
Data{name: ~"data", value: 10}
Data{name: data, value: 10}
-----
times = 30
-----
value: 10
value: zero
-----
calc = 100

今回使用したサンプルソースhttp://github.com/fits/try_samples/tree/master/blog/20140223/

R で個体差のあるロジスティック回帰2 - MCMCglmm

前回 に続き、今回は個体差を考慮したロジスティック回帰を MCMCglmm で試してみます。

実は MCMCglmm の他にも MCMCpack の MCMChlogit や bayesm の rhierBinLogit 等といろいろ試そうとしてみたのですが、イマイチ使い方が分からなかったので今回は断念しました。

今回使用したサンプルソースhttp://github.com/fits/try_samples/tree/master/blog/20140211/

はじめに

今回使用するパッケージを R へインストールしておきます。

install.packages("MCMCglmm")

(2) MCMCglmm を使った階層ベイズモデルによるロジスティック回帰(MCMCglmm 関数)

それでは、前回 と同じデータ (data7.csv) を使って推定と予測線のグラフ化を実施してみます。

MCMCglmm() による推定

MCMCglmm() 関数は glmmML() 関数とほぼ同じ使い方ができますが、 今回のケースでは下記の点が異なります。

  • family に multinomial2 を指定する
  • デフォルトで個体差 (units に該当) が考慮されるようなので特に id 列を指定する必要はない

デフォルトで詳細(処理状況)が出力されるので、下記では出力しないよう verbose = FALSE としています。

logiMcmcglmm.R
d <- read.csv('data7.csv')

library(MCMCglmm)

d.res <- MCMCglmm(cbind(y, N - y) ~ x, data = d, family = "multinomial2", verbose = FALSE)

summary(d.res)
・・・

実行結果は下記のようになります。

実行結果
> d.res <- MCMCglmm(cbind(y, N - y) ~ x, data = d, family = "multinomial2", verbose = FALSE)
> 
> summary(d.res)

 Iterations = 3001:12991
 Thinning interval  = 10
 Sample size  = 1000 

 DIC: 667.8652 

 R-structure:  ~units

      post.mean l-95% CI u-95% CI eff.samp
units     6.988    3.898    10.58    517.4

 Location effects: cbind(y, N - y) ~ x 

            post.mean l-95% CI u-95% CI eff.samp  pMCMC    
(Intercept)   -4.2271  -6.3586  -2.5264    814.3 <0.001 ***
x              1.0144   0.6117   1.5237    800.3 <0.001 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

各パラメータの平均値 (post.mean) が { \bar \beta_1 = -4.2271, \bar \beta_2 = 1.0144, \bar s = \sqrt{6.988} = 2.643 } で、前回の結果(glmmML)と近い値になりました。

ここで units は前回の個体差パラメータ { r_i } の分散 { s ^2 } に該当するようですので、標準偏差とするために平方根を取りました。

プログラム上で上記の結果を使う場合は、MCMCglmm() の結果から { \beta_1, \beta_2 } (Intercept, x)の分布を $Sol で、{ s ^2 } (units)の分布を $VCV で取得できます。

なお、実行する度に若干異なる値になるようなので、何らかのオプションを指定して調整する必要があるのかもしれません。

葉数と生存種子数

MCMCglmm() の結果を使って 前回 と同様にグラフ化してみます。

今回はパラメータの平均値(post.mean の値)を使わずに、事後最頻値 (posterior mode) というものを使って予測線を描画してみました。

事後最頻値は posterior.mode() 関数で取得できます。

logiMcmcglmm.R
・・・
# 生存確率の算出
calcProb <- function(x, b, r)
    1.0 / (1.0 + exp(-1 * (b[1] + b[2] * x + r)))

png("logiMcmcglmm_1.png")

plot(d$x, d$y)

xx <- seq(min(d$x), max(d$x), length = 100)

# beta1 と beta2 の事後最頻値
beta <- posterior.mode(d.res$Sol)
# s の事後最頻値
sd <- sqrt(posterior.mode(d.res$VCV))

lines(xx, max(d$N) * calcProb(xx, beta, 0), col="green")
lines(xx, max(d$N) * calcProb(xx, beta, -1 * sd), col="blue")
lines(xx, max(d$N) * calcProb(xx, beta, sd), col="blue")

dev.off()
・・・

f:id:fits:20140211122143p:plain

ついでに、前々回 のようにpredict() を使って予測線を描画しようとしてみましたが、今のところ predict.MCMCglmm() は新しいデータの指定に対応していないようです。

> predict(d.res, data.frame(x = xx))

Error in predict.MCMCglmm(d.res, data.frame(x = xx)) : 
  sorry newdata not implemented yet

葉数 { x_i = 4 } の種子数分布

こちらも事後最頻値を使って、前回 と同様にグラフ化してみます。

logiMcmcglmm.R
・・・
png("logiMcmcglmm_2.png")

yy <- 0:max(d$N)

plot(yy, table(d[d$x == 4,]$y), xlab="y", ylab="num")

# 葉数 x を固定した場合の生存種子数 y の確率分布を算出
calcL <- function(ylist, xfix, n, b, s)
  sapply(ylist, function(y) integrate(
      f = function(r) dbinom(y, n, calcProb(xfix, b, r)) * dnorm(r, 0, s),
      lower = s * -10,
      upper = s * 10
    )$value
  )

lines(yy, calcL(yy, 4, max(d$N), beta, sd) * length(d[d$x == 4,]$y), col="red", type="b")

dev.off()

f:id:fits:20140211122201p:plain