ジニ不純度の算出 - Groovy, Scala , Java 8, Frege
書籍 「集合知プログラミング」 の 「7章 決定木によるモデリング」 にあったジニ不純度(ジニ係数)の計算を下記の JVM 言語で関数言語的に実装してみました。
今回のソースは http://github.com/fits/try_samples/tree/master/blog/20140601/
はじめに
ジニ不純度の算出は、下記の (1) と (2) で異なるアイテムを取り出す確率を求める事になります。
- (1) ある集合から 1つアイテムを取り出す
- (2) 取り出したアイテムを戻して再度 1つアイテムを取り出す
例えば、["A", "B", "B", "C", "B", "A"]
のような集合の場合に A、B、C を取り出す確率はそれぞれ以下のようになります。
A = 2/6 = 1/3 B = 3/6 = 1/2 C = 1/6
ここで、ジニ不純度は以下のような 2通りの方法で算出できます。
(下記の XY
は (1) で X が出て (2) で Y が出る確率を表している)
(a) ジニ不純度 = 1 - (AA + BB + CC) = 1 - (1/3 × 1/3 + 1/2 × 1/2 + 1/6 × 1/6) = 11/18 = 0.61 (b) ジニ不純度 = AB + AC + BA + BC + CA + CB = 1/3 × 1/2 + 1/3 × 1/6 + ・・・ = 11/18 = 0.61
(a) の方がシンプルな実装になると思います。
Groovy で実装
それではそれぞれの言語で実装してみます。
Groovy では countBy
メソッドで要素毎のカウント値を簡単に取得できます。
異なる要素同士の組み合わせは、今回 nCopies
、combinations
、findAll
を使って取得しました。
下記で list.countBy {it}
の結果は [A:2, B:3, C:1]
、
nCopies(2, list.countBy { it })
の結果は [[A:2, B:3, C:1], [A:2, B:3, C:1]]
、
nCopies(2, list.countBy { it }).combinations()
の結果は [[A=2, A=2], [B=3, A=2], ・・・, [B=3, C=1], [C=1, C=1]]
となります。
gini.groovy
import static java.util.Collections.nCopies // (a) 1 - (AA + BB + CC) def giniA = { xs -> 1 - xs.countBy { it }*.value.sum { (it / xs.size()) ** 2 } } // (b) AB + AC + BA + BC + CA + CB def giniB = { xs -> nCopies(2, xs.countBy { it }).combinations().findAll { // 同じ要素同士の組み合わせを除外 it.first().key != it.last().key }.sum { (it.first().value / xs.size()) * (it.last().value / xs.size()) } } def list = ['A', 'B', 'B', 'C', 'B', 'A'] println giniA(list) println giniB(list)
実行結果
> groovy gini.groovy 0.61111111112222222222 0.61111111112222222222
Scala で実装
Scala では Groovy の countBy に該当するメソッドが無さそうだったので groupBy
を使いました。
List で combinations(2)
とすればリスト内要素の 2要素の組み合わせ (下記では AB、AC、BC の 3種類の組み合わせ) を取得できます。
下記で list.groupBy(identity)
の結果は Map(A -> List(A, A), C -> List(C), B -> List(B, B, B))
となります。
gini.scala
import scala.math.pow // (a) 1 - (AA + BB + CC) val giniA = (xs: List[_]) => 1 - xs.groupBy(identity).mapValues( v => pow(v.size.toDouble / xs.size, 2) ).values.sum // (b) AC × 2 + AB × 2 + CB × 2 val giniB = (xs: List[_]) => xs.groupBy(identity).mapValues( v => v.size.toDouble / xs.size ).toList.combinations(2).map( x => x.head._2 * x.last._2 * 2 ).sum val list = List("A", "B", "B", "C", "B", "A") println( giniA(list) ) println( giniB(list) )
実行結果
> scala gini.scala 0.6111111111111112 0.611111111111111
Java 8 で実装
Java 8 の Stream API では groupingBy
と counting
メソッドを組み合わせて collect
すると要素毎のカウントを取得できます。
要素の組み合わせを取得するようなメソッドは無さそうだったので自前で実装しました。
下記で countBy(list)
の結果は {A=2, B=3, C=1}
、
combination(countBy(list))
の結果は [[A=2, B=3], [A=2, C=1], [B=3, A=2], [B=3, C=1], [C=1, A=2], [C=1, B=3]]
のようになります。
Gini.java
import static java.util.stream.Collectors.*; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Stream; class Gini { public static void main(String... args) { List<String> list = Arrays.asList("A", "B", "B", "C", "B", "A"); System.out.println( giniA(list) ); System.out.println( giniB(list) ); } // (a) 1 - (AA + BB + CC) private static double giniA(List<String> xs) { return 1 - countBy(xs).values().stream().mapToDouble( x -> Math.pow(x.doubleValue() / xs.size(), 2) ).sum(); } // (b) AB + AC + BA + BC + CA + CB private static double giniB(List<String> xs) { return combination(countBy(xs)).stream().mapToDouble( s -> s.stream().mapToDouble( t -> t.getValue().doubleValue() / xs.size() ).reduce(1.0, (a, b) -> a * b ) ).sum(); } private static <T> Map<T, Long> countBy(Collection<T> xs) { return xs.stream().collect(groupingBy(Function.identity(), counting())); } private static <T, S> Collection<? extends List<Map.Entry<T, S>>> combination(Map<T, S> data) { return data.entrySet().stream().flatMap( x -> data.entrySet().stream().flatMap ( y -> (x.getKey().equals(y.getKey()))? Stream.empty(): Stream.of(Arrays.asList(x, y)) ) ).collect(toList()); } }
実行結果
> java Gini 0.6111111111111112 0.611111111111111
Frege で実装
Frege の group
関数では連続した同じ要素をグルーピングしますので sort
してから使う必要があります。
下記で、group . sort $ list
の結果は [["A", "A"], ["B", "B", "B"], ["C"]]
となります。
組み合わせ (AB, AC 等) の確率計算にはリスト内包表記を使ってみました。
gini.fr
package sample.Gini where import frege.prelude.Math (**) import Data.List size = fromIntegral . length -- (a) 1 - (AA + BB + CC) giniA xs = (1 - ) . sum . map calc . group . sort $ xs where listSize = size xs calc x = (size x / listSize) ** 2 -- (b) AB + AC + BA + BC + CA + CB giniB xs = fold (+) 0 . calcProb . map prob . group . sort $ xs where listSize = size xs prob ys = (head ys, size ys / listSize) calcProb zs = [ snd x * snd y | x <- zs, y <- zs, fst x /= fst y] main args = do let list = ["A", "B", "B", "C", "B", "A"] println $ giniA list println $ giniB list
実行結果
> java -cp .;frege3.21.586-g026e8d7.jar sample.Gini 0.6111111111111112 0.611111111111111 runtime ・・・
備考
giniB 関数の fold (+) 0
の部分は sum
でも問題ないように思うのですが 、sum を使うと下記のようなエラーが発生しました。
giniB 関数で sum を使った場合のエラー内容
E sample.fr:14: inferred type is more constrained than expected type inferred: (Real t17561,Show t17561) => [String] -> IO () expected: [String] -> IO ()
ちなみに、ほぼ同じコードが Haskell で動作するのですが、Haskell の場合は sum を使っても問題ありませんでした。 (gini.hs 参照)