ガウス混合分布のパラメータをscikit-learnで推定する

scikit-learnでガウス混合分布のパラメータをさくっと推定する方法がありましたので、その備忘録です。

ガウス混合分布

ガウス混合分布は、複数のガウス分布を線形結合した分布で、以下式で表されます。


p(\vec{x})=\sum_i^N w_i N(\vec{x} | \mu_i, \sigma^2_i)

  • N: ガウス分布数 (ハイパパラメータ)
  •  w_i : ガウス分布の重み ( \sum_i^N w_i=1)

パラメータは  w_i, \mu_i, \sigma^2_i で、3×N個となります。

音声認識などでは、このガウス混合分布モデルと隠れマルコフモデルと組み合わせた手法が使われています。

scikit-learnのGaussianMixture

ガウス混合分布はEMアルゴリズムでパラメータを推定することになりますが、scikit-learnではガウス混合分布のパラメータ推定を行うツールとしてGaussianMixtureが提供されています。

irisデータセットの萼片長 (sepal length) を使って、GaussianMixtureによるパラメータ推定を実装していきます。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import seaborn as sns

from sklearn.mixture import GaussianMixture
from sklearn.datasets import load_iris

# irisデータセットのロード
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target

# 種類 (ラベル) によって、サンプル数を変えます
d1 = df[df['label'] == 0].sample(30)  # setosa
d2 = df[df['label'] == 1].sample(50)  # versicolor
d3 = df[df['label'] == 2].sample(40)  # virginica

# 萼片長のデータのみを使う
X = pd.concat([d1['sepal length (cm)'], d2['sepal length (cm)'], d3['sepal length (cm)']])
Y = pd.concat([d1['label'], d2['label'], d3['label']])

# ヒストグラム
plt.hist([X[Y==0], X[Y==1], X[Y==2]], bins=np.arange(X.min(), X.max(), 0.2), stacked=True, label=iris.target_names)
plt.legend()
plt.plot()

種類 (0: setosa, 1: versicolor, 2: virginica) によって萼片長の分布は異なっています。
足し合わさっているので、全体的には5.0, 5.6, 6.2付近をピークとした多峰性の分布となっていますが、いずれも混じりあっているため、境界によってきれいに分離できません。

f:id:ohke:20190608173834p:plain

それではGaussianMixtureで学習させます。

  • n_componentsは分布数で、3を設定
  • covariance_typeで分散の種類を選択でき、"spherical"とすると各分布は単一の分散になります
  • 推定されたパラメータは、weights_, means_, covariance_ にそれぞれ入ってます
# GaussianMixtureの学習
gmm = GaussianMixture(
    n_components=3,
    covariance_type='spherical'
).fit(
    np.array(X).reshape(-1, 1)  # 次元数2を入力とするため変形
)

# 重み
print(gmm.weights_)
# [0.309359   0.43752389 0.25311711]

# 期待値
print(gmm.means_)
# [[5.80722255]
#  [6.57885203]
#  [4.93215112]]

# 分散
print(gmm.covariances_)
# [0.10564111 0.32196481 0.06607687]

この3つの分布を描画します。

x = np.linspace(3, 9, 600)

gd1 = mlab.normpdf(x, gmm.means_[0, -1], np.sqrt(gmm.covariances_[0]))
gd2 = mlab.normpdf(x, gmm.means_[1, -1], np.sqrt(gmm.covariances_[1]))
gd3 = mlab.normpdf(x, gmm.means_[2, -1], np.sqrt(gmm.covariances_[2]))
    
plt.plot(x, gmm.weights_[0] * gd1, label='gd1')
plt.plot(x, gmm.weights_[1] * gd2, label='gd2')
plt.plot(x, gmm.weights_[2] * gd3, label='gd3')
plt.legend()
plt.show()

期待値4.9, 5.8, 6.6の3つの正規分布が描かれています。またgb1 (versicolor) とgb2 (virginica) は重なりも大きいことがわかります。

f:id:ohke:20190608181445p:plain

ある値xに対して、それがどの分布が占める割合が高いかをpredictメソッドで計算することができ、これにより (ソフト) クラスタリングができます。
この例の場合、71.7%の精度で正しく判定できました。

# 属する分布を予測
Y_predict = gmm.predict(np.array(X).reshape(-1, 1))
print(Y_predict)
# [2 2 0 2 2 2 2 2 2 2 2 2 2 2 0 0 2 0 2 2 2 0 2 2 2 2 2 2 2 2 0 0 2 1 0 1 0
#  0 0 0 0 2 0 0 1 0 2 1 0 1 0 1 0 1 0 0 0 1 0 0 1 2 0 0 0 1 1 0 0 1 1 0 0 0
#  0 0 1 1 2 1 1 1 0 1 1 1 1 1 0 1 0 0 0 1 1 1 1 1 0 2 1 1 1 1 1 1 1 1 0 1 1
#  1 1 1 1 1 1 1 1 1]

# ラベルを 0->2, 1->0, 2->1 へ置き換える
Y_new = Y.copy()
Y_new[Y==0] = 2
Y_new[Y==1] = 0
Y_new[Y==2] = 1

# 精度を計算
print(sum(Y_new == Y_predict) / len(Y_new))
# 0.7166666666666667