ジニ不純度の算出 - Groovy, Scala , Java 8, Frege

書籍 「集合知プログラミング」 の 「7章 決定木によるモデリング」 にあったジニ不純度(ジニ係数)の計算を下記の JVM 言語で関数言語的に実装してみました。

今回のソースは http://github.com/fits/try_samples/tree/master/blog/20140601/

はじめに

ジニ不純度の算出は、下記の (1) と (2) で異なるアイテムを取り出す確率を求める事になります。

  • (1) ある集合から 1つアイテムを取り出す
  • (2) 取り出したアイテムを戻して再度 1つアイテムを取り出す

例えば、["A", "B", "B", "C", "B", "A"] のような集合の場合に A、B、C を取り出す確率はそれぞれ以下のようになります。

A = 2/6 = 1/3
B = 3/6 = 1/2
C = 1/6 

ここで、ジニ不純度は以下のような 2通りの方法で算出できます。 (下記の XY は (1) で X が出て (2) で Y が出る確率を表している)

(a) ジニ不純度 = 1 - (AA + BB + CC) = 1 - (1/3 × 1/3 + 1/2 × 1/2 + 1/6 × 1/6) = 11/18 = 0.61
(b) ジニ不純度 = AB + AC + BA + BC + CA + CB = 1/3 × 1/2 + 1/3 × 1/6 + ・・・ = 11/18 = 0.61

(a) の方がシンプルな実装になると思います。

Groovy で実装

それではそれぞれの言語で実装してみます。

Groovy では countBy メソッドで要素毎のカウント値を簡単に取得できます。

異なる要素同士の組み合わせは、今回 nCopiescombinationsfindAll を使って取得しました。

下記で list.countBy {it} の結果は [A:2, B:3, C:1]
nCopies(2, list.countBy { it }) の結果は [[A:2, B:3, C:1], [A:2, B:3, C:1]]
nCopies(2, list.countBy { it }).combinations() の結果は [[A=2, A=2], [B=3, A=2], ・・・, [B=3, C=1], [C=1, C=1]]
となります。

gini.groovy
import static java.util.Collections.nCopies

// (a) 1 - (AA + BB + CC)
def giniA = { xs ->
    1 - xs.countBy { it }*.value.sum { (it / xs.size()) ** 2 }
}

// (b) AB + AC + BA + BC + CA + CB
def giniB = { xs ->
    nCopies(2, xs.countBy { it }).combinations().findAll {
        // 同じ要素同士の組み合わせを除外
        it.first().key != it.last().key
    }.sum {
        (it.first().value / xs.size()) * (it.last().value / xs.size())
    }
}

def list = ['A', 'B', 'B', 'C', 'B', 'A']

println giniA(list)
println giniB(list)
実行結果
> groovy gini.groovy

0.61111111112222222222
0.61111111112222222222

Scala で実装

Scala では Groovy の countBy に該当するメソッドが無さそうだったので groupBy を使いました。

List で combinations(2) とすればリスト内要素の 2要素の組み合わせ (下記では AB、AC、BC の 3種類の組み合わせ) を取得できます。

下記で list.groupBy(identity) の結果は Map(A -> List(A, A), C -> List(C), B -> List(B, B, B)) となります。

gini.scala
import scala.math.pow

// (a) 1 - (AA + BB + CC)
val giniA = (xs: List[_]) => 1 - xs.groupBy(identity).mapValues( v =>
    pow(v.size.toDouble / xs.size, 2)
).values.sum

// (b) AC × 2 + AB × 2 + CB × 2
val giniB = (xs: List[_]) => xs.groupBy(identity).mapValues( v =>
    v.size.toDouble / xs.size
).toList.combinations(2).map( x =>
    x.head._2 * x.last._2 * 2
).sum

val list = List("A", "B", "B", "C", "B", "A")

println( giniA(list) )
println( giniB(list) )
実行結果
> scala gini.scala

0.6111111111111112
0.611111111111111

Java 8 で実装

Java 8 の Stream API では groupingBycounting メソッドを組み合わせて collect すると要素毎のカウントを取得できます。

要素の組み合わせを取得するようなメソッドは無さそうだったので自前で実装しました。

下記で countBy(list) の結果は {A=2, B=3, C=1}
combination(countBy(list)) の結果は [[A=2, B=3], [A=2, C=1], [B=3, A=2], [B=3, C=1], [C=1, A=2], [C=1, B=3]]
のようになります。

Gini.java
import static java.util.stream.Collectors.*;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;

class Gini {
    public static void main(String... args) {
        List<String> list = Arrays.asList("A", "B", "B", "C", "B", "A");

        System.out.println( giniA(list) );
        System.out.println( giniB(list) );
    }

    // (a) 1 - (AA + BB + CC)
    private static double giniA(List<String> xs) {
        return 1 - countBy(xs).values().stream().mapToDouble( x -> Math.pow(x.doubleValue() / xs.size(), 2) ).sum();
    }

    // (b) AB + AC + BA + BC + CA + CB
    private static double giniB(List<String> xs) {
        return combination(countBy(xs)).stream().mapToDouble( s ->
            s.stream().mapToDouble( t -> 
                t.getValue().doubleValue() / xs.size()
            ).reduce(1.0, (a, b) -> a * b ) 
        ).sum();
    }

    private static <T> Map<T, Long> countBy(Collection<T> xs) {
        return xs.stream().collect(groupingBy(Function.identity(), counting()));
    }

    private static <T, S> Collection<? extends List<Map.Entry<T, S>>> combination(Map<T, S> data) {
        return data.entrySet().stream().flatMap( x ->
            data.entrySet().stream().flatMap ( y ->
                (x.getKey().equals(y.getKey()))? Stream.empty(): Stream.of(Arrays.asList(x, y))
            )
        ).collect(toList());
    }
}
実行結果
> java Gini

0.6111111111111112
0.611111111111111

Frege で実装

Frege の group 関数では連続した同じ要素をグルーピングしますので sort してから使う必要があります。

下記で、group . sort $ list の結果は [["A", "A"], ["B", "B", "B"], ["C"]] となります。

組み合わせ (AB, AC 等) の確率計算にはリスト内包表記を使ってみました。

gini.fr
package sample.Gini where

import frege.prelude.Math (**)
import Data.List

size = fromIntegral . length

-- (a) 1 - (AA + BB + CC)
giniA xs = (1 - ) . sum . map calc . group . sort $ xs
    where
        listSize = size xs
        calc x = (size x / listSize) ** 2

-- (b) AB + AC + BA + BC + CA + CB
giniB xs = fold (+) 0 . calcProb . map prob . group . sort $ xs
    where
        listSize = size xs
        prob ys = (head ys, size ys / listSize)
        calcProb zs = [ snd x * snd y | x <- zs, y <- zs, fst x /= fst y]

main args = do
    let list = ["A", "B", "B", "C", "B", "A"]

    println $ giniA list
    println $ giniB list
実行結果
> java -cp .;frege3.21.586-g026e8d7.jar sample.Gini

0.6111111111111112
0.611111111111111
runtime ・・・

備考

giniB 関数の fold (+) 0 の部分は sum でも問題ないように思うのですが 、sum を使うと下記のようなエラーが発生しました。

giniB 関数で sum を使った場合のエラー内容
E sample.fr:14: inferred type is more constrained than expected type
    inferred:  (Real t17561,Show t17561) => [String] -> IO ()
    expected:  [String] -> IO ()

ちなみに、ほぼ同じコードが Haskell で動作するのですが、Haskell の場合は sum を使っても問題ありませんでした。 (gini.hs 参照)