Python 回帰木でセッション数を予測するモデルを作成する

前回の投稿では線形回帰を使ってセッション数を予測しましたが、今回は回帰木を使ってみます。

Python GoogleAnalyticsのデータを使って線形回帰でセッション数を予測するモデルを作る - け日記

回帰木による学習・テスト

前回の投稿では、本ブログの1日あたりのセッション数(来訪回数)を予測するために、2つの説明変数(4/1を0として何日目かを表すnth、休日を表すholiday)を2次の多項式に拡張して、線形回帰で学習・テストさせました。

今回使う回帰木は、説明変数に閾値を設けて学習サンプルを分割していき(多くの実装は二分割)、リーフでの目的変数の平均値をそのリーフに分類されたサンプルの予測値とするものです。

散布図を見てみますと、概ね4つのグループ(四角枠)に分けられそうです。 回帰木で学習すると、例えばholidayが0ならばAグループ、0ならばBグループ、Aグループの内nthが40以下ならA-1グループ、41以上ならA-2グループ、・・・というように、今ある説明変数だけでもざっくりと分割できると予想されます。

scikit-learnを使った回帰木の学習では、sklearn.tree.DecisionTreeRegressorを使います。

  • 学習データとテストデータを3:1で分割しています
  • 回帰木の深さが深いほど過学習を起こしやすくなるため、最大深さ3としています
    • 4グループに分割できればいいので、二分木であれば深さ2までで良さそうですが、おそらく最初にholidayを使って平日(上部3グループ)と休日(下部1グループ)に分けられるので、平日の3グループをさらに分割するためには深さ2では足りないと予想されたため、3にしました
  • 決定木は閾値による分類のため、説明変数の正規化・標準化を必要ありません
from sklearn.tree import DecisionTreeRegressor

# features_dfの取得については前回の記事を参照
X = features_df[['nth', 'holiday']]
y = features_df['sessions']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 最大深さ3の回帰木を作成
regressor = DecisionTreeRegressor(max_depth=3)

# 学習・テスト
regressor.fit(X_train, y_train)
print('Train score: {:.3f}'.format(regressor.score(X_train, y_train)))
print('Test score: {:.3f}'.format(regressor.score(X_test, y_test)))
# Train score: 0.949
# Test score: 0.862

# プロット
plt.scatter(X['nth'], y)
plt.plot(X['nth'], pipeline.predict(X))
plt.show()

結果として決定係数R2は、学習データで0.949、テストデータで0.862となり、前回の線形回帰を使ったモデル(学習データで0.916、テストデータで0.832)よりも少し改善されました。 また、プロットした図でも、概ね最初に予想した4グループに分割されていることが確認できます。

ただし、回帰木では長期的な増加傾向を反映できていないので、7月以降のセッション数をこれを使って予測すると、おそらく線形回帰よりも性能は悪くなると予想されます。

f:id:ohke:20170720081901p:plain

回帰木を描画

最後に、学習で獲得した回帰木を描画し、どういった条件でサンプルを分割させているのかを見てみます。 回帰木の描画ではdotファイルを出力し、graphvizで画像ファイル(png)に変換します。

Graphviz | Graphviz - Graph Visualization Software

まずgraphvizをインストールしておきます。

$ brew install graphviz

Pythonからはsklearn.tree.export_graphvizをインポートして、export_graphvizでdotファイルを出力します。

from sklearn.tree import export_graphviz

export_graphviz(regressor, out_file='tree.dot', feature_names=['nth', 'holiday'])

最後に、graphvizのdotコマンドでpngファイルへ出力します。

$ dot -Tpng tree.dot -o tree.png

以下のような回帰木が図として出力されます。 valueがそのグループの平均値で、mseが平均二乗誤差です。

当初の見立て通り、最初にholidayを閾値として平日と休日のグループに分け、以降はnthの値を閾値として分割していることがわかります。 例えば、4/3(nth=3, holiday=0)のセッション数は、一番左下のノードに分類されて74.3と予測されます。

f:id:ohke:20170720082249p:plain