LightGBM で多クラス分類する方法
🍪この記事の内容:
- LightGBM (Scikit-learn API) で多クラス分類する。
- このとき、デフォルトでは各サンプルを真クラスと予測する確率の対数の和を最小化するような分岐及び分岐後スコア割り当てが目指される。ただし、分岐位置の探索時にこの損失を厳密に計算するのではなく、その分岐でサンプルたちを分けてそれぞれスコアを調整した場合の損失改善幅を 2 次までのテイラー展開で近似し (split_gain)、split_gain が最大の分岐を選択する。
参考文献
- lightgbm.LGBMClassifier — LightGBM 4.6.0.99 documentation#lightgbm-lgbmclassifier, , 2025年12月25日参照.
- LightGBM: Fast Gradient Boosting with Leaf-wise Tree Growth - Complete Guide with Math Formulas & Python Implementation - Interactive | Michael Brenndoerfer | Michael Brenndoerfer, , 2025年12月25日参照.
- microsoft/LightGBM: A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks., , 2025年12月25日参照.
iris データセットの 3 クラス分類
Scikit-learn から iris データセットを取得し、うち 60% (90 サンプル) を訓練データに、40% (60 サンプル) を評価データ兼テストデータにして学習した例が以下である。なお、わかりやすさのために葉の枚数をあえて制限していることに留意。- iris データセットにはアヤメの萼片の長さと幅、花弁の長さと幅、品種 (セトーサ、バーシカラー、バージニカ) が 150 サンプル含まれ、品種分類タスク向けにデザインされている。
- LightGBM で 3 クラス分類をするときは 1 ラウンドごとに各クラスらしさを出す 3 本の木が学習される。以下では 55 ラウンド目以降 3 ラウンド連続で損失が改善せず学習が止まるので、出力された
iris_lgbm.txtをみると 165 本の木の分岐位置が記述されている。 - 出力されたレポートをみると、テストデータ中のセトーサは全て正しくセトーサに分類されているが、バージニカのうち 1 サンプルがバーシカラーに誤分類されていることがわかる。
- このモデルオブジェクトを保存し、リロードして予測すると当然だが同じ結果になる。
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import pandas as pd
import joblib # モデルオブジェクト保存用 (再利用用)
# アイリスデータセットのロード
# https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html
data = load_iris(as_frame=True) # データフレームとして取得
assert type(data.data).__name__ == 'DataFrame'
assert type(data.target).__name__ == 'Series'
assert data.data.shape == (150, 4)
assert data.data.columns.to_list() == [
'sepal length (cm)', 'sepal width (cm)', # 萼片
'petal length (cm)', 'petal width (cm)', # 花弁
]
assert data.target.shape == (150,)
assert data.target.name == 'target'
assert list(data.target.unique()) == [0, 1, 2]
assert list(data.target_names) == ['setosa', 'versicolor', 'virginica']
# 特徴量名に空白があると怒られるので排除しておく
data.data.columns = [col.replace(' ', '') for col in data.data.columns]
# トレインテストスプリット
# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
x_train, x_test, y_train, y_test = train_test_split(
data.data,
data.target,
test_size=0.4,
random_state=42
)
assert x_train.shape == (90, 4)
assert x_test.shape == (60, 4)
assert type(x_train).__name__ == type(x_test).__name__ == 'DataFrame'
# LGBM モデルインスタンス作成
# https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html#lightgbm-lgbmclassifier
# https://lightgbm.readthedocs.io/en/latest/Parameters.html
model = lgb.LGBMClassifier(
objective='multiclass',
random_state=42,
num_leaves=4, # 1 本の木には最大 4 枚の葉
n_estimators=100, # 最大 100 ラウンド
learning_rate=0.1,
verbosity=-1,
)
# フィット
# https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html#lightgbm.LGBMClassifier.fit
model.fit(
x_train,
y_train,
eval_set=(x_test, y_test),
eval_metric='multi_logloss',
callbacks=[
lgb.early_stopping(stopping_rounds=3), # 3 ラウンド連続で損失が改善しなかったら停止
lgb.log_evaluation(1),
]
)
# テストデータの予測と結果レポート
pd.options.display.precision = 3 # 小数点以下 3 桁表示
def report(y_pred):
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
d = classification_report(y_test, y_pred, target_names=data.target_names, output_dict=True)
print('accuracy: {:.3f}'.format(d['accuracy']))
del d['accuracy']
print(pd.DataFrame(d).T) # accuracy 以外の値は同じキーを持つ辞書
y_pred = model.predict(x_test)
report(y_pred)
# 学習結果の保存
model.booster_.save_model('iris_lgbm.txt') # 分岐位置情報テキストの保存 (解釈用)
joblib.dump(model, 'iris_lgbm.joblib') # オブジェクトごと保存 (リロード用)
# 学習結果のリロードとテストデータの予測 (再)
model_loaded = joblib.load('iris_lgbm.joblib')
y_pred_loaded = model_loaded.predict(x_test)
report(y_pred_loaded)
[1] valid_0's multi_logloss: 0.943794
Training until validation scores don't improve for 3 rounds
[2] valid_0's multi_logloss: 0.810079
[3] valid_0's multi_logloss: 0.702025
...
[54] valid_0's multi_logloss: 0.0361876
[55] valid_0's multi_logloss: 0.0357035
[56] valid_0's multi_logloss: 0.0357368
[57] valid_0's multi_logloss: 0.0368628
[58] valid_0's multi_logloss: 0.03578
Early stopping, best iteration is:
[55] valid_0's multi_logloss: 0.0357035
accuracy: 0.983
precision recall f1-score support
setosa 1.000 1.000 1.000 23.0
versicolor 0.950 1.000 0.974 19.0
virginica 1.000 0.944 0.971 18.0
macro avg 0.983 0.981 0.982 60.0
weighted avg 0.984 0.983 0.983 60.0
# リロードしても評価指標は同じなので略
[参考] LightGBM の多クラス分類の目的関数について
これ以降、わかりやすさのために 1 本の木の葉の枚数を 2 枚に絞り (1 回の分岐のみ)、学習率を 1 にし、1 ラウンドで学習を止める。このときの結果は以下である。recall をみると、セトーサは 23 サンプル中 1 サンプル、バーシカラーは 19 サンプル中 2 サンプル、バージニカは 18 サンプル中 1 サンプルの検出が漏れている。さらに presision をみると、検出が漏れたセトーサとバージニカはバーシカラーに、検出が漏れたバーシカラーはバージニカに誤分類されている。
[1] valid_0's multi_logloss: 0.347051
Training until validation scores don't improve for 3 rounds
Did not meet early stopping. Best iteration is:
[1] valid_0's multi_logloss: 0.347051
accuracy: 0.933
precision recall f1-score support
setosa 1.000 0.957 0.978 23.0
versicolor 0.895 0.895 0.895 19.0
virginica 0.895 0.944 0.919 18.0
macro avg 0.930 0.932 0.930 60.0
weighted avg 0.935 0.933 0.934 60.0
multi_logloss: 0.347051 (予測確率の対数損失) は以下のように検算できる。つまり、各サンプルの「真クラスである確率をいくらと予測したかの対数」の平均の -1 倍である。すべてのサンプルに対して真クラスである確率を 1 と予測すれば log(1) = 0 よりこれは 0 になる。
import numpy as np
df = pd.DataFrame(model.predict_proba(x_test))
df.columns = [f'p({s})' for s in data.target_names]
df['true'] = y_test.map(lambda x: data.target_names[x]).values
df['p(true)'] = df.apply(lambda row: row['p({})'.format(row['true'])], axis=1)
df['log(p(true))'] = df['p(true)'].map(np.log)
pd.options.display.precision = 5
print(df.head())
print(-df['log(p(true))'].mean()) # multi_logloss
p(setosa) p(versicolor) p(virginica) true p(true) log(p(true))
0 0.15550 0.66559 0.17891 versicolor 0.66559 -0.40708
1 0.91278 0.04108 0.04614 setosa 0.91278 -0.09126
2 0.05022 0.21497 0.73481 virginica 0.73481 -0.30814
3 0.15550 0.66559 0.17891 versicolor 0.66559 -0.40708
4 0.15550 0.66559 0.17891 versicolor 0.66559 -0.40708
0.34705110563457964
LightGBM はデフォルトで上記の「予測確率の対数損失」の最小化を目指す。しかし、分岐位置の探索時にこの損失の減少幅を厳密に計算しているのではなく、2 次までのテイラー展開で近似している [2]。例えば、今回の学習で 1 ラウンド目の 1 本目の木 (セトーサの分離を目指す木) は iris_lgbm.txt より以下である (この木の分岐は「花弁の長さが 1.8cm 以下か」というものだが、実際訓練データ 90 サンプルのうちこれを満たす 26 サンプルは全てセトーサであり、27 サンプルあったセトーサのほぼ全てが左の葉に振り分けられる)。
Tree=0
num_leaves=2
num_cat=0
split_feature=2 # 花びらの長さが
split_gain=56.875
threshold=1.8 # 1.8cm 未満
decision_type=2
left_child=-1
right_child=-2
leaf_value=1.0182493968717188 -2.1067506267809168
leaf_weight=8.1899999380111712 20.159999847412109
leaf_count=26 64
internal_value=-1.20397 # セトーサの割合 27 / 90 の対数に等しい
internal_weight=0
internal_count=90
is_linear=0
shrinkage=1
この木の分岐のゲイン split_gain=56.875 は損失の減少幅を近似したものであり、以下のように検算できる。陽に出てこないが 1 ラウンド目なので初期のセトーサ予測値は全員同じで、訓練データ内のセトーサの割合 27 / 90 の対数である。
# leaf_weight は leaf 内の各サンプルの損失関数のヘシアンの合計値
h_l = 8.1899999380111712 # 左組のヘシアンの合計
h_r = 20.159999847412109 # 右組のヘシアンの合計
h = h_l + h_r # 全員のヘシアンの合計
# leaf 内の各サンプルの勾配の合計を leaf_value = - g / h (ニュートン方向のようなもの) から逆算
# g = - h * leaf_value
leaf_value_l = 1.0182493968717188 # 左組ニュートン方向
leaf_value_r = -2.1067506267809168 # 右組ニュートン方向
g_l = - h_l * leaf_value_l # 左組の勾配の合計
g_r = - h_r * leaf_value_r # 右組の勾配の合計
g = g_l + g_r # 全員の勾配の合計
leaf_value = - g / h # 全員ニュートン方向
d_l = g_l * leaf_value_l # 左組たちの予測値を一斉に左組ニュートン方向だけ調整したときの損失変化量
d_r = g_r * leaf_value_r # 右組たちの予測値を一斉に右組ニュートン方向だけ調整したときの損失変化量
d = g * leaf_value # 全員の予測値を一斉に全員ニュートン方向だけ調整したときの損失変化量
print(-(d_l + d_r - d)) # 左組と右組を分けた方がこれだけ損失を減らせる = split_gain 56.875
また、上記のヘシアンの合計は以下のように検算できる。ただ、勾配の合計は合わない (ソースコードを参照してもなぜ合わないのかまだ解決できていない)。
import numpy as np
def softmax(x, axis=-1):
x = np.asarray(x)
x_max = np.max(x, axis=axis, keepdims=True)
e = np.exp(x - x_max)
return e / np.sum(e, axis=axis, keepdims=True)
p = softmax([ # 初期スコア及びそのソフトマックス
np.log(27 / 90),
np.log(31 / 90),
np.log(32 / 90),
])
print( # ヘシアンの値 (3 クラス分類のときファクター 3 / 2 をかける [3])
(3 / 2) * 26 * p[0] * (1 - p[0]), # 左組のヘシアンの合計
(3 / 2) * 64 * p[0] * (1 - p[0]), # 右組のヘシアンの合計
(3 / 2) * 90 * p[0] * (1 - p[0]), # 全員のヘシアンの合計
) # 8.19 20.16 28.35 (これは合っている)
print( # 勾配の値
26 * (p[0] - 1), # 左組の勾配の合計
(p[0] - 1) + 63 * p[0], # 右組の勾配の合計
27 * (p[0] - 1) + 63 * p[0], # 全員の勾配の合計
) # -18.2 18.2 0 (これは合っていない)
なお、実際には学習率を 1 未満にするので、それが leaf_value に乗じられることに留意。正則化を適用した場合にも leaf_value は変化する。