Pandas

Pythonで機械学習を学ぶ ハイパーパラメータチューニング グリッドサーチ

汎化性能を向上するために、ハイパーパラメータ(学習器に人が設定する引数、例えば決定木の深さなど)のチューニング手法の一つグリッドサーチを学びます。scikit-learnでは関数が用意されており簡単に活用できます。

グリッドサーチ(GridSearchCV)

以前に作成したSVMに適用してベストモデルを構築します。

import numpy as np
import pandas as pd
import requests
import io
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV

# 読み込み
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/forest-fires/forestfires.csv'
res = requests.get(url).content
data = pd.read_csv(io.StringIO(res.decode('utf-8')),
                   header=None, skiprows=1)
data.columns = ['X', 'Y', 'month', 'day', 'FFMC', 'DMC',
                'DC', 'ISI', 'temp', 'RH', 'wind', 'rain', 'area']

# 確認
print('データ形式:{}'.format(data.shape))
print(data.head())
# print(data.dtypes)

# 10以上のときにTureにする
data['flg'] = data['area'].map(lambda x: 1 if x >= 10 else 0)

# 目的変数、説明変数
X = data[['FFMC', 'DMC', 'DC', 'ISI', 'temp', 'wind', 'rain']]
y = data['flg']

# 訓練データとテストデータの準備
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

# 標準化
sc = StandardScaler()
sc.fit(X_train)
X_train_std = sc.transform(X_train)
X_test_std = sc.transform(X_test)

# GridSearchで設定するパラメータを準備
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100],
              'kernel': ['rbf', 'linear', 'poly'],
              'degree': np.arange(1, 6, 1),
              'gamma': np.linspace(0.01, 1.0, 20)
              }

# GridSearch、k分割交差検証
gs = GridSearchCV(estimator=SVC(max_iter=10000),
                  param_grid=param_grid,
                  cv=5)

# ベストモデルの構築
gs.fit(X_train_std, y_train)

# 結果表示
print('Best score:{:.3f}'.format(gs.best_score_))
print('Best parameters:{}'.format(gs.best_params_))
print('Test score:{:.3f}'.format(gs.score(X_test_std, y_test)))

実行結果

Best score:0.826
Best parameters:{'C': 0.001, 'degree': 1, 'gamma': 0.01, 'kernel': 'rbf'}
Test score:0.779

# グリッドサーチをせずに、デフォルトで学習させたときの実行結果
train:0.758
test:0.740