본문 바로가기
CS/머신러닝

XGBoost, LightGBM 사용과 하이퍼파라미터

by jaehoonChoi 2023. 2. 11.

XGBoost와 LightGBM은 엄청난 성능을 자랑하는 앙상블 알고리즘입니다. 

이 둘은 모두 사이킷런 전용 클래스가 있습니다. 사이킷런 전용 클래스는 fit과 predict, score등

직관적이고 일관화된 모델훈련&평가식을 제공하므로 사이킷런 전용 클래스를 사용합니다.

 

1. XGBoost의 주요 파라미터

 

n_estimators: 훈련횟수를 의미합니다.  훈련을 너무 많이 해도 과적합이 될 수 있습니다. 

 

learning_rate: 학습률을 의미합니다. 0~1사이 값이며, 주로 0.01에서 0.2정도사이에서 결정합니다.

다른 하이퍼 파라미터를 조정후 미세하게 조정하며 성능을 올리기도 합니다. 

 

min_child_weight: 트리에서 가지를 칠지 결정하는 가중치의 합입니다. 과적합을 제어하는

주요 하이퍼 파라미터입니다.

 

max_depth: 트리기반 알고리즘에서 핵심 파라미터입니다. 깊이제한이 클수록 과적합이 발생하므로,

적절히 작게 조절하고 앙상블을 통해 성능을 올리는게 주요목적입니다. 

 

colsample_bytree: GBM의 max_features와 유사한 파라미터로, 트리생성에 필요한 feature을 샘플링하는데

쓰입니다. feature가 매우 많은 경우 과적합을 막아주는 역할을 합니다.

 

 

2.  LightGBM 주요 파라미터

n_estimators, learning_rate, max_depth : xgboost와 동일

 

num_leaves: 하나의 트리가 가지는 최대 leaf 개수입니다. 주요 하이퍼파라미터입니다.

 

min_child_samples: 주요 하이퍼파라미터입니다. 큰 값으로 설정하면 트리가 깊어지는걸 방지합니다. 

 

 

3.  조기중단

조기중단은 계속 훈련함에도 성능개선이 보이지 않을 때 그냥 끝내는 것입니다. 

early_stopping_rounds 파라미터를 통해 반복횟수를 정의하고, 조기중단 평가지표인 eval_metric

정해줍니다. 그리고 성능평가를 수행하는 eval_set을 만들어줍니다. 

[(train data), (valid data)] 꼴로 결정해주면 됩니다. 

 

 

 

4. 훈련과 평가

from xgboost import XGBClassifier #import
# xgboost model 
xgb_model=XGBClassifier(n_estimators=400, learning_rate=0.05, max_depth=3, 
                       eval_metric='logloss')
# fit 
xgb_model.fit(X_train, y_train, verbose=True)
# predict 
print(xgb.score(X_test, y_test))

 

위와 같이 일반적인 사이킷런 훈련방식과 동일합니다. 주요 파라미터를 넣어주고 돌리면 됩니다. 

LightGBM도 마찬가지입니다.

from lightgbm import LGBMClassifier #import
# LightGBM model
lgbm= LGBMClassifier(n_estimators=400, learning_rate=0.05)
# 조기종료 훈련, 검증데이터셋
evals=[(X_tr, y_tr), (X_val, y_val)]
# fit  50번이상 성능개선 안되면 종료, 평가지표는 logloss 
lgbm.fit(X_tr, y_tr, early_stopping_rounds=50, eval_metric='logloss',
       eval_set=evals, verbose=True)
# 예측 값 
preds=lgbm.predict(X_test)

 

 5. 중요도 표시

xgboost와 lightbgm은 트리기반 알고리즘이므로 feature별 중요도도 출력해줍니다.

plot_importance 를 import하여 ax객체를 입력하여 사용하면 됩니다.

from lightgbm import plot_importance
f, ax=plt.subplots(figsize=(10, 12))
plot_importance(lgbm, ax=ax);

댓글