ジニ不純度の算出2 - Ruby, C#, F#, Erlang

前回 に続き、今回は下記のようなプログラム言語でジニ不純度(ジニ係数)の算出処理を同じように実装してみました。

今回のソースは http://github.com/fits/try_samples/tree/master/blog/20140608/

Ruby で実装

Ruby では group_by で要素毎の Hash オブジェクトを取得できます。
(下記では {"A"=>["A", "A"], "B"=>["B", "B", "B"], "C"=>["C"]}

なお、Hash で map した結果は配列になります。
(下記 list.group_by {|x| x }.map {|k, v| [k, v.size.to_f / xs.size] } の結果は [["A", 0.33・・・], ["B", 0.5], ["C", 0.16・・・]]

また、combination(2) で前回の Scala の関数と同様に 2要素の組み合わせを取得できます。
(下記では、[["A", 0.33・・・], ["B", 0.5]], [["A", 0.33・・・], ["C", 0.16・・・]], [["B", 0.5], ["C", 0.16・・・]]

gini.rb
#coding:utf-8

# (a) 1 - (AA + BB + CC)
def giniA(xs)
    1 - xs.group_by {|x| x }.inject(0) {|a, (k, v)| a + (v.size.to_f / xs.size) ** 2 }
end

# (b) AB × 2 + AC × 2 +  BC × 2
def giniB(xs)
    xs.group_by {|x| x }.map {|k, v| [k, v.size.to_f / xs.size] }.combination(2).inject(0) {|a, t| a + t.first.last * t.last.last * 2}
end

list = ["A", "B", "B", "C", "B", "A"]

puts giniA(list)
puts giniB(list)
実行結果
> ruby gini.rb

0.6111111111111112
0.611111111111111

C# で実装

LINQGroupBy メソッドを使えば要素毎にグルーピングした IGrouping<TKey, TSource> のコレクションを取得できます。

要素の組み合わせも LINQ のクエリ式を使えば簡単に作成できます。(下記の combination メソッド

gini.cs
using System;
using System.Collections.Generic;
using System.Linq;

class Gini
{
    public static void Main(string[] args)
    {
        var list = new List<string>() {"A", "B", "B", "C", "B", "A"};

        Console.WriteLine("{0}", giniA(list));
        Console.WriteLine("{0}", giniB(list));
    }

    // (a) 1 - (AA + BB + CC)
    private static double giniA<K>(IEnumerable<K> xs)
    {
         return 1 - xs.GroupBy(x => x).Select(x => Math.Pow((double)x.Count() / xs.Count(), 2)).Sum();
    }

    // (b) AB + AC + BA + BC + CA + CB
    private static double giniB<K>(IEnumerable<K> xs)
    {
        return
            combination(
                countBy(xs).Select(t =>
                    Tuple.Create(t.Item1, (double)t.Item2 / xs.Count())
                )
            ).Select(x => x.Item1.Item2 * x.Item2.Item2).Sum();
    }

    private static IEnumerable<Tuple<K, int>> countBy<K>(IEnumerable<K> xs) {
        return xs.GroupBy(x => x).Select(g => Tuple.Create(g.Key, g.Count()));
    }

    // 異なる要素の組み合わせを作成
    private static IEnumerable<Tuple<Tuple<K, V>, Tuple<K, V>>> combination<K, V>(IEnumerable<Tuple<K, V>> data) {
        return
            from x in data
            from y in data
            where !x.Item1.Equals(y.Item1)
            select Tuple.Create(x, y);
    }
} 
実行結果
> csc gini.cs
> gini.exe

0.611111111111111
0.611111111111111

F# で実装

  • F# 3.1

F# では Seq.countBy で要素毎のカウント値を取得できます。
(下記では seq [("A", 2); ("B", 3); ("C", 1)]

要素の組み合わせは内包表記を使えば簡単に作成できます。 (下記の combinationCount)

gini.fs
let size xs = xs |> Seq.length |> float

// (a) 1 - (AA + BB + CC)
let giniA xs = xs |> Seq.countBy id |> Seq.sumBy (fun (k, v) -> (float v / size xs) ** 2.0) |> (-) 1.0

let combinationCount cs = [
    for x in cs do
        for y in cs do
            if fst x <> fst y then
                yield (snd x, snd y)
]

// (b) AB + AC + BA + BC + CA + CB
let giniB xs = xs |> Seq.countBy id |> combinationCount |> Seq.sumBy (fun (x, y) -> (float x / size xs) * (float y / size xs))

let list = ["A"; "B"; "B"; "C"; "B"; "A";]

printfn "%A" (giniA list)
printfn "%A" (giniB list)
実行結果
> fsc gini.fs
> gini.exe

0.6111111111
0.6111111111

Erlang で実装

Erlang ではグルーピング処理等は用意されていないようなので自前で実装しました。 (今回は dict モジュールを使いました)

リスト内包表記 ([<構築子> || <限定子>, ・・・]) で使用するジェネレータ (<変数> <- <式>) の右辺はリストになる式を指定する必要があるため、dict:to_list() でリスト化しています。

gini.erl
-module(gini).
-export([main/1]).

groupBy(Xs) -> lists:foldr(fun(X, Acc) -> dict:append(X, X, Acc) end, dict:new(), Xs).

countBy(Xs) -> dict:map( fun(_, V) -> length(V) end, groupBy(Xs) ).

% (a) 1 - (AA + BB + CC)
giniA(Xs) -> 1 - lists:sum([ math:pow(V / length(Xs), 2) || {_, V} <- dict:to_list(countBy(Xs)) ]).

combinationProb(Xs) -> [ {Vx, Vy} || {Kx, Vx} <- Xs, {Ky, Vy} <- Xs, Kx /= Ky ].

% (b) AB + AC + BA + BC + CA + CB
giniB(Xs) -> lists:sum([ (Vx / length(Xs)) * (Vy / length(Xs)) || {Vx, Vy} <- combinationProb(dict:to_list(countBy(Xs))) ]).

main(_) ->
    List = ["A", "B", "B", "C", "B", "A"],

    io:format("~p~n", [ giniA(List) ]),
    io:format("~p~n", [ giniB(List) ]).
実行結果
> escript gini.erl

0.6111111111111112
0.611111111111111

なお、groupBy(List)countBy(List) の結果を出力すると下記のようになりました。

groupBy(List) の出力結果
{dict,3,16,16,8,80,48,
      {[],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[]},
      {{[],
        [["B","B","B","B"]],
        [],[],[],[],
        [["C","C"]],
        [],[],[],[],[],
        [["A","A","A"]],
        [],[],[]}}}
countBy(List) の出力結果
{dict,3,16,16,8,80,48,
      {[],[],[],[],[],[],[],[],[],[],[],[],[],[],[],[]},
      {{[],
        [["B"|3]],
        [],[],[],[],
        [["C"|1]],
        [],[],[],[],[],
        [["A"|2]],
        [],[],[]}}}