決定木にて分類の基準によくジニ係数(Gini inpurity / Gini index)という尺度が使われる。
このジニ係数について少し考察してみたのでメモ。
データセット t の中から無作為に2回取り出すベルヌーイ試行*1を考えたときに、同じクラスのサンプルが取り出される確率は \sum_{i=1}^{K}P^2(C_i|t) 。その逆で異なるクラスのサンプルが取り出される確率は、 1- \sum_{i=1}^{K}P^2(C_i|t)となり、これがジニ係数の定義と一致するというわけ。
ここのジニ係数を計算する関数を参考にした。ただしジニ係数の定義上、ジニ係数の計算時に存在するクラスを事前に知る必要がないはず*3なのでデータセットだけを関数に渡す方式に変更している。
(*1): サンプルを複数回無作為に取り出す際に、一度取り出したサンプルを「戻して」再度サンプルを無作為に取り出す試行。
(*2): 初めてのパターン認識の(11.11)式。
(*3): たとえ対象のデータセットの中に現れないクラスが存在したとしてもそれはジニ係数に寄与しない。
このジニ係数について少し考察してみたのでメモ。
ジニ係数の定義と挙動
あるデータセットtの中にK種のクラスのサンプルが含まれる場合、ジニ係数は
I(t)= \sum_{i\neq j} P(C_i|t)P(C_j|t)=\sum_{i=1}^{K} P(C_i|t)(1-P(C_i|t) = 1- \sum_{i=1}^{K}P^2(C_i|t) と定義される。 ここでP(C_i|t)はデータセット t の中に含まれるクラス i のサンプル数の割合である。
この定義式を元に2つのクラス(0, 1)が含まれる20個のデータを例にクラス0が含まれる数に応じたジニ係数の変化をプロットしたのが下図。期待通りデータセットの中に各クラスのデータが均等に含まれれば含まれるほどジニ係数の値は大きくなり、偏って含まれれる場合は小さくなり、純粋に1つのクラスしか含まれない場合はゼロになる。
I(t)= \sum_{i\neq j} P(C_i|t)P(C_j|t)=\sum_{i=1}^{K} P(C_i|t)(1-P(C_i|t) = 1- \sum_{i=1}^{K}P^2(C_i|t) と定義される。 ここでP(C_i|t)はデータセット t の中に含まれるクラス i のサンプル数の割合である。
この定義式を元に2つのクラス(0, 1)が含まれる20個のデータを例にクラス0が含まれる数に応じたジニ係数の変化をプロットしたのが下図。期待通りデータセットの中に各クラスのデータが均等に含まれれば含まれるほどジニ係数の値は大きくなり、偏って含まれれる場合は小さくなり、純粋に1つのクラスしか含まれない場合はゼロになる。
ジニ係数の解釈
ジニ係数の定義は「データセットの中から無作為に2つサンプルを取り出したときに異なるクラスのサンプルが取り出される確率」と解釈するとわかりやすい。データセット t の中から無作為に2回取り出すベルヌーイ試行*1を考えたときに、同じクラスのサンプルが取り出される確率は \sum_{i=1}^{K}P^2(C_i|t) 。その逆で異なるクラスのサンプルが取り出される確率は、 1- \sum_{i=1}^{K}P^2(C_i|t)となり、これがジニ係数の定義と一致するというわけ。
複数データセットでのジニ係数
決定木アルゴリズムでは決定木のノードで元のデータセットを2つ(LとR)に分割する。そのためこの分割後の2つのデータセット全体の不純度が小さい分割の仕方を選ばないといけない。そのため決定木アルゴリズムでは、 I_{split}(t)= p_{L}I(t_{L}) + p_{R}I(t_{R}) の量が最小となる分割方法を見つけることになる*2。ここでp_{L}とp_{R}はそれぞれ分割後のデータセットのサンプル数の元のデータセットに対する割合。ジニ係数と交差エントロピー
はじめてのパターン認識でも書かれているけど、不純度の尺度はジニ係数だけではなく、交差エントロピー I(t)= -\sum_{i=1}^{K}P(C_i|t)\log P(C_i|t) も使われることがある。ただし、挙動がほぼ同じということと、ジニ係数であれば、計算コストの高いLogの計算をしなくて良いため、ジニ係数が利用されることが多い様子。ジニ係数の実装@Python3
折角なのでPythonでジニ係数を計算する関数をPythonで実装してみた。上記 I_{split}(t)を計算するコードです。ここのジニ係数を計算する関数を参考にした。ただしジニ係数の定義上、ジニ係数の計算時に存在するクラスを事前に知る必要がないはず*3なのでデータセットだけを関数に渡す方式に変更している。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# ジニ係数の計算 | |
def gini_impurity(datasets): | |
data_all = np.concatenate(datasets, axis=0) # データセットを結合 | |
n_all = len(data_all) # 全サンプル数 | |
class_set = set(data_all) #データセットに含まれているユニークなクラスのセットを取り出し | |
if(len(class_set) == 1): # クラスが1つしか含まれてなければ計算するまでもなくgini係数は0 | |
return 0.0 | |
gini = 0.0 | |
for dataset in datasets: | |
size = len(dataset) | |
# 分割後のデータセットの要素数がゼロならスキップ(空要素はジニ係数には影響しない) | |
if size == 0: | |
continue | |
score = 0.0 | |
for class_val in class_set: | |
p = np.sum(dataset == class_val) / size # class_valに一致する要素の数を全体数で割る。 | |
score += p * p | |
gini += (1.0 - score) * (float(size) / float(n_all)) # はじめてのパターン認識(11.11)式の後半 | |
return gini |
(*2): 初めてのパターン認識の(11.11)式。
(*3): たとえ対象のデータセットの中に現れないクラスが存在したとしてもそれはジニ係数に寄与しない。
0 件のコメント:
コメントを投稿