XGBoost is a Python library that provides an optimized implementation of gradient boosting. It is fast, scalable, and portable, and its strong performance has made it a go-to tool for top contenders in machine learning competitions such as Kaggle.
It uses boosting: samples that the previous model predicted poorly are given more weight, so the next model focuses on correcting those errors and prediction accuracy improves step by step.
It is faster than plain GBM, supports early stopping, and provides ways to curb overfitting. It can be used for both classification and regression. Because it is not bundled with scikit-learn, it has to be installed separately.
pip install xgboost
The command above is all that is needed to install it.
XGBoost can be used in two ways: through the native Python wrapper or through the scikit-learn wrapper. With the Python wrapper, the train and test data must first be converted into separate DMatrix objects.
*DMatrix: XGBoost's own dataset format, created from NumPy arrays (among other inputs)
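As a minimal sketch (the arrays below are placeholders, not the breast cancer data used later), creating a DMatrix looks like this:

import numpy as np
import xgboost as xgb

X = np.random.rand(100, 5)               # placeholder feature matrix
y = np.random.randint(0, 2, size=100)    # placeholder binary labels

# label is optional; leave it out for data you only want to predict on
dmat = xgb.DMatrix(data=X, label=y)
print(dmat.num_row(), dmat.num_col())    # 100 5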
Main parameters
- Booster params (model configuration)

Python wrapper | scikit-learn wrapper | Description
---|---|---
eta (0.3) | learning_rate (0.1) | How much each boosting step's contribution is scaled; smaller values mean later trees are reflected less (typically chosen between 0.01 and 0.2)
num_boost_round (10) | n_estimators (100) | Number of weak learners to build
min_child_weight (1) | min_child_weight (1) | Minimum sum of instance weights required in a child node; used to control overfitting
gamma (0) | gamma (0) | Minimum loss reduction required to make a further split on a leaf node (alias: min_split_loss); larger values reduce overfitting
max_depth (6) | max_depth (3) | Maximum tree depth
subsample (1) | subsample (1) | Fraction of the training data sampled per tree; values between 0.5 and 1 are typical
colsample_bytree (1) | colsample_bytree (1) | Fraction of features sampled per tree; values between 0.5 and 1 are typical
lambda (1) | reg_lambda (1) | L2 (ridge) regularization weight; larger values reduce overfitting
alpha (0) | reg_alpha (0) | L1 (lasso) regularization weight; larger values reduce overfitting
scale_pos_weight (1) | scale_pos_weight (1) | Used to balance the classes in an imbalanced dataset
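To make the mapping between the two wrappers concrete, here is a rough sketch of the same configuration expressed both ways (the values are arbitrary examples, not recommendations):

import xgboost as xgb
from xgboost import XGBClassifier

# native (Python wrapper) style: booster parameters go in a plain dict
params = {'eta': 0.05, 'max_depth': 4, 'subsample': 0.8,
          'colsample_bytree': 0.8, 'lambda': 1.0}
# booster = xgb.train(params, dtrain, num_boost_round=200)

# scikit-learn wrapper style: the same settings as constructor keyword arguments
clf = XGBClassifier(learning_rate=0.05, max_depth=4, subsample=0.8,
                    colsample_bytree=0.8, reg_lambda=1.0, n_estimators=200)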
How to adjust parameters to prevent overfitting (a sketch follows this list):
- Lower eta (and, to compensate, raise num_boost_round / n_estimators)
- Lower max_depth
- Raise min_child_weight
- Raise gamma
- Lower subsample and colsample_bytree
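As an illustration only (these numbers are arbitrary, not tuned values), a more heavily regularized configuration might look like this:

# a "more regularized" booster configuration relative to the defaults
params_regularized = {
    'eta': 0.05,              # smaller learning rate
    'max_depth': 3,           # shallower trees
    'min_child_weight': 5,    # require more instance weight per leaf
    'gamma': 1.0,             # require a larger loss reduction before splitting
    'subsample': 0.8,         # row sampling
    'colsample_bytree': 0.8,  # column sampling
}
# train for more rounds to compensate for the smaller eta
# booster = xgb.train(params_regularized, dtrain, num_boost_round=1000,
#                     evals=..., early_stopping_rounds=...)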
- Learning Task params (objective function and evaluation metric for training)

Parameter | Options
---|---
objective | reg:linear: regression (default; renamed reg:squarederror in recent versions) / binary:logistic: binary classification / multi:softmax: multiclass classification, returns the predicted class / multi:softprob: multiclass classification, returns class probabilities
eval_metric | The default depends on the objective (regression: rmse, classification: error); options include rmse, mae, logloss, error, merror, mlogloss, auc
Early stopping is performed with early_stopping_rounds, and to use it an evaluation set (eval_set / evals) and eval_metric must be set together.
*eval_set: the evaluation dataset used to measure performance
*eval_metric: the evaluation metric applied to that dataset
(At every boosting round, the prediction error on eval_set is measured with the eval_metric metric.)
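For reference, the same idea in the scikit-learn wrapper looks roughly like the sketch below. Where early_stopping_rounds is passed has changed across xgboost releases (older versions take it in fit(), recent ones in the constructor), so this is a sketch assuming a recent release rather than a definitive recipe:

from xgboost import XGBClassifier

# assumes X_train/y_train and X_val/y_val already exist
clf = XGBClassifier(n_estimators=400, learning_rate=0.1,
                    eval_metric='logloss', early_stopping_rounds=100)
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)])
# training stops once the validation logloss has not improved for 100 rounds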
Python wrapper: code walkthrough
import xgboost as xgb
from xgboost import plot_importance ## to visualize feature importance
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
data = load_breast_cancer()
X_features = data.data
y_label = data.target
cancer = pd.DataFrame(data=X_features, columns = data.feature_names)
cancer['target'] = y_label
print(cancer.head())
print(cancer['target'].value_counts())
>>>1 357
0 212
Name: target, dtype: int64
Here 1 is benign and 0 is malignant, so this is a binary classification problem.
As mentioned earlier, when training with the Python wrapper, the train, validation, and test sets must be converted into DMatrix form.
#train,val,test split
X_train, X_test, y_train, y_test = train_test_split(X_features, y_label, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)
# convert to DMatrix
dtrain = xgb.DMatrix(data=X_train, label=y_train)
dval = xgb.DMatrix(data=X_val, label=y_val)
dtest = xgb.DMatrix(data=X_test)
# Build the model: this is binary classification, so the objective is binary:logistic
# The evaluation metric is logloss
# Training runs for up to num_boost_round iterations and stops if there is no improvement for early_stopping_rounds rounds
# To use early_stopping_rounds, the evaluation set (evals) and eval_metric must be set together
params = {'max_depth' : 3,
          'eta' : 0.1,
          'objective' : 'binary:logistic',
          'eval_metric' : 'logloss'}
# note: early stopping is passed to xgb.train() below as early_stopping_rounds, not inside params
# the training set is labeled 'train' and the evaluation set 'eval' in the output
xgb_model = xgb.train(params = params, dtrain = dtrain, num_boost_round = 400,
                      early_stopping_rounds = 100, evals=[(dtrain,'train'),(dval,'eval')])
>>>
[0] train-logloss:0.61110 eval-logloss:0.61392
[1] train-logloss:0.54569 eval-logloss:0.54705
[2] train-logloss:0.48914 eval-logloss:0.49040
[3] train-logloss:0.43928 eval-logloss:0.44400
[4] train-logloss:0.39793 eval-logloss:0.40186
[5] train-logloss:0.36053 eval-logloss:0.36743
[6] train-logloss:0.32755 eval-logloss:0.33466
[7] train-logloss:0.30003 eval-logloss:0.30660
[8] train-logloss:0.27555 eval-logloss:0.28448
[9] train-logloss:0.25406 eval-logloss:0.26223
[10] train-logloss:0.23388 eval-logloss:0.24419
[11] train-logloss:0.21688 eval-logloss:0.22693
[12] train-logloss:0.20016 eval-logloss:0.21070
[13] train-logloss:0.18490 eval-logloss:0.19531
[14] train-logloss:0.17197 eval-logloss:0.18232
[15] train-logloss:0.15976 eval-logloss:0.17115
[16] train-logloss:0.14969 eval-logloss:0.16147
[17] train-logloss:0.14022 eval-logloss:0.15193
[18] train-logloss:0.13169 eval-logloss:0.14399
[19] train-logloss:0.12290 eval-logloss:0.13664
[20] train-logloss:0.11589 eval-logloss:0.12951
[21] train-logloss:0.10925 eval-logloss:0.12318
[22] train-logloss:0.10244 eval-logloss:0.11784
[23] train-logloss:0.09700 eval-logloss:0.11318
[24] train-logloss:0.09153 eval-logloss:0.10907
[25] train-logloss:0.08691 eval-logloss:0.10518
[26] train-logloss:0.08220 eval-logloss:0.10217
[27] train-logloss:0.07829 eval-logloss:0.09938
[28] train-logloss:0.07394 eval-logloss:0.09597
[29] train-logloss:0.07034 eval-logloss:0.09340
[30] train-logloss:0.06734 eval-logloss:0.09138
[31] train-logloss:0.06367 eval-logloss:0.08916
[32] train-logloss:0.06075 eval-logloss:0.08714
[33] train-logloss:0.05812 eval-logloss:0.08535
[34] train-logloss:0.05554 eval-logloss:0.08305
[35] train-logloss:0.05271 eval-logloss:0.08174
[36] train-logloss:0.05067 eval-logloss:0.07960
[37] train-logloss:0.04836 eval-logloss:0.07736
[38] train-logloss:0.04611 eval-logloss:0.07662
[39] train-logloss:0.04443 eval-logloss:0.07508
[40] train-logloss:0.04253 eval-logloss:0.07395
[41] train-logloss:0.04079 eval-logloss:0.07352
[42] train-logloss:0.03943 eval-logloss:0.07205
[43] train-logloss:0.03793 eval-logloss:0.07180
[44] train-logloss:0.03653 eval-logloss:0.07093
[45] train-logloss:0.03526 eval-logloss:0.07048
[46] train-logloss:0.03406 eval-logloss:0.06987
[47] train-logloss:0.03284 eval-logloss:0.06952
[48] train-logloss:0.03189 eval-logloss:0.06944
[49] train-logloss:0.03105 eval-logloss:0.06965
[50] train-logloss:0.03018 eval-logloss:0.06890
[51] train-logloss:0.02931 eval-logloss:0.06899
[52] train-logloss:0.02829 eval-logloss:0.06883
[53] train-logloss:0.02746 eval-logloss:0.06835
[54] train-logloss:0.02672 eval-logloss:0.06752
[55] train-logloss:0.02589 eval-logloss:0.06762
[56] train-logloss:0.02524 eval-logloss:0.06753
[57] train-logloss:0.02452 eval-logloss:0.06768
[58] train-logloss:0.02398 eval-logloss:0.06784
[59] train-logloss:0.02343 eval-logloss:0.06817
[60] train-logloss:0.02276 eval-logloss:0.06838
[61] train-logloss:0.02227 eval-logloss:0.06819
[62] train-logloss:0.02184 eval-logloss:0.06834
[63] train-logloss:0.02124 eval-logloss:0.06872
[64] train-logloss:0.02076 eval-logloss:0.06858
[65] train-logloss:0.02044 eval-logloss:0.06864
[66] train-logloss:0.02011 eval-logloss:0.06891
[67] train-logloss:0.01974 eval-logloss:0.06880
[68] train-logloss:0.01946 eval-logloss:0.06890
[69] train-logloss:0.01910 eval-logloss:0.06899
[70] train-logloss:0.01873 eval-logloss:0.06865
[71] train-logloss:0.01835 eval-logloss:0.06816
[72] train-logloss:0.01799 eval-logloss:0.06778
[73] train-logloss:0.01760 eval-logloss:0.06757
[74] train-logloss:0.01720 eval-logloss:0.06760
[75] train-logloss:0.01699 eval-logloss:0.06769
[76] train-logloss:0.01677 eval-logloss:0.06784
[77] train-logloss:0.01649 eval-logloss:0.06777
[78] train-logloss:0.01622 eval-logloss:0.06759
[79] train-logloss:0.01598 eval-logloss:0.06765
[80] train-logloss:0.01572 eval-logloss:0.06825
[81] train-logloss:0.01554 eval-logloss:0.06800
[82] train-logloss:0.01534 eval-logloss:0.06776
[83] train-logloss:0.01515 eval-logloss:0.06765
[84] train-logloss:0.01494 eval-logloss:0.06681
[85] train-logloss:0.01473 eval-logloss:0.06735
[86] train-logloss:0.01459 eval-logloss:0.06745
[87] train-logloss:0.01442 eval-logloss:0.06762
[88] train-logloss:0.01426 eval-logloss:0.06737
[89] train-logloss:0.01411 eval-logloss:0.06730
[90] train-logloss:0.01394 eval-logloss:0.06786
[91] train-logloss:0.01375 eval-logloss:0.06708
[92] train-logloss:0.01363 eval-logloss:0.06718
[93] train-logloss:0.01349 eval-logloss:0.06736
[94] train-logloss:0.01332 eval-logloss:0.06663
[95] train-logloss:0.01316 eval-logloss:0.06654
[96] train-logloss:0.01298 eval-logloss:0.06710
[97] train-logloss:0.01287 eval-logloss:0.06720
[98] train-logloss:0.01275 eval-logloss:0.06738
[99] train-logloss:0.01259 eval-logloss:0.06752
[100] train-logloss:0.01249 eval-logloss:0.06762
[101] train-logloss:0.01237 eval-logloss:0.06773
[102] train-logloss:0.01224 eval-logloss:0.06765
[103] train-logloss:0.01210 eval-logloss:0.06700
[104] train-logloss:0.01198 eval-logloss:0.06685
[105] train-logloss:0.01189 eval-logloss:0.06695
[106] train-logloss:0.01176 eval-logloss:0.06724
[107] train-logloss:0.01165 eval-logloss:0.06744
[108] train-logloss:0.01157 eval-logloss:0.06753
[109] train-logloss:0.01148 eval-logloss:0.06765
[110] train-logloss:0.01138 eval-logloss:0.06753
[111] train-logloss:0.01127 eval-logloss:0.06782
[112] train-logloss:0.01119 eval-logloss:0.06792
[113] train-logloss:0.01111 eval-logloss:0.06810
[114] train-logloss:0.01104 eval-logloss:0.06794
[115] train-logloss:0.01094 eval-logloss:0.06792
[116] train-logloss:0.01082 eval-logloss:0.06778
[117] train-logloss:0.01075 eval-logloss:0.06787
[118] train-logloss:0.01071 eval-logloss:0.06814
[119] train-logloss:0.01062 eval-logloss:0.06802
[120] train-logloss:0.01056 eval-logloss:0.06814
[121] train-logloss:0.01049 eval-logloss:0.06822
[122] train-logloss:0.01043 eval-logloss:0.06833
[123] train-logloss:0.01032 eval-logloss:0.06820
[124] train-logloss:0.01026 eval-logloss:0.06828
[125] train-logloss:0.01022 eval-logloss:0.06855
[126] train-logloss:0.01014 eval-logloss:0.06845
[127] train-logloss:0.01007 eval-logloss:0.06844
[128] train-logloss:0.00997 eval-logloss:0.06832
[129] train-logloss:0.00990 eval-logloss:0.06825
[130] train-logloss:0.00987 eval-logloss:0.06812
[131] train-logloss:0.00983 eval-logloss:0.06817
[132] train-logloss:0.00980 eval-logloss:0.06803
[133] train-logloss:0.00977 eval-logloss:0.06829
[134] train-logloss:0.00974 eval-logloss:0.06800
[135] train-logloss:0.00971 eval-logloss:0.06834
[136] train-logloss:0.00968 eval-logloss:0.06856
[137] train-logloss:0.00965 eval-logloss:0.06844
[138] train-logloss:0.00962 eval-logloss:0.06816
[139] train-logloss:0.00959 eval-logloss:0.06848
[140] train-logloss:0.00956 eval-logloss:0.06873
[141] train-logloss:0.00953 eval-logloss:0.06861
[142] train-logloss:0.00950 eval-logloss:0.06859
[143] train-logloss:0.00948 eval-logloss:0.06863
[144] train-logloss:0.00945 eval-logloss:0.06887
[145] train-logloss:0.00942 eval-logloss:0.06918
[146] train-logloss:0.00939 eval-logloss:0.06891
[147] train-logloss:0.00937 eval-logloss:0.06879
[148] train-logloss:0.00934 eval-logloss:0.06877
[149] train-logloss:0.00931 eval-logloss:0.06899
[150] train-logloss:0.00929 eval-logloss:0.06930
[151] train-logloss:0.00926 eval-logloss:0.06903
[152] train-logloss:0.00924 eval-logloss:0.06926
[153] train-logloss:0.00921 eval-logloss:0.06924
[154] train-logloss:0.00919 eval-logloss:0.06955
[155] train-logloss:0.00916 eval-logloss:0.06964
[156] train-logloss:0.00914 eval-logloss:0.06938
[157] train-logloss:0.00911 eval-logloss:0.06961
[158] train-logloss:0.00909 eval-logloss:0.06991
[159] train-logloss:0.00907 eval-logloss:0.06989
[160] train-logloss:0.00904 eval-logloss:0.06989
[161] train-logloss:0.00902 eval-logloss:0.06978
[162] train-logloss:0.00900 eval-logloss:0.06974
[163] train-logloss:0.00898 eval-logloss:0.06949
[164] train-logloss:0.00895 eval-logloss:0.06978
[165] train-logloss:0.00893 eval-logloss:0.07000
[166] train-logloss:0.00891 eval-logloss:0.06998
[167] train-logloss:0.00889 eval-logloss:0.06998
[168] train-logloss:0.00887 eval-logloss:0.06996
[169] train-logloss:0.00885 eval-logloss:0.07025
[170] train-logloss:0.00883 eval-logloss:0.07001
[171] train-logloss:0.00881 eval-logloss:0.07022
[172] train-logloss:0.00878 eval-logloss:0.07019
[173] train-logloss:0.00876 eval-logloss:0.07008
[174] train-logloss:0.00874 eval-logloss:0.07036
[175] train-logloss:0.00872 eval-logloss:0.07012
[176] train-logloss:0.00870 eval-logloss:0.07012
[177] train-logloss:0.00868 eval-logloss:0.07010
[178] train-logloss:0.00866 eval-logloss:0.07031
[179] train-logloss:0.00864 eval-logloss:0.06997
[180] train-logloss:0.00862 eval-logloss:0.06996
[181] train-logloss:0.00860 eval-logloss:0.06995
[182] train-logloss:0.00859 eval-logloss:0.06985
[183] train-logloss:0.00857 eval-logloss:0.07012
[184] train-logloss:0.00855 eval-logloss:0.07002
[185] train-logloss:0.00853 eval-logloss:0.06979
[186] train-logloss:0.00851 eval-logloss:0.06975
[187] train-logloss:0.00849 eval-logloss:0.06996
[188] train-logloss:0.00847 eval-logloss:0.06995
[189] train-logloss:0.00846 eval-logloss:0.06994
[190] train-logloss:0.00844 eval-logloss:0.07019
[191] train-logloss:0.00842 eval-logloss:0.07018
[192] train-logloss:0.00840 eval-logloss:0.06985
[193] train-logloss:0.00838 eval-logloss:0.07006
[194] train-logloss:0.00837 eval-logloss:0.06983
Since this is a binary classification model, objective was set to binary:logistic and eval_metric to logloss.
Early stopping was used via early_stopping_rounds, and in that case the evaluation set and eval_metric must be supplied together; in evals, dtrain was labeled 'train' and dval was labeled 'eval'.
As training proceeds, train-logloss keeps falling, while eval-logloss bottoms out around round 95 (0.06654) and stops improving; training therefore stops near round 194, well before the 400-round limit, once the 100-round early-stopping patience runs out. Now that training is complete, all that remains is to predict with the predict method and evaluate performance.
# Predict: the native API returns probabilities
y_pred_probs = xgb_model.predict(dtest)
# convert probabilities to 0 or 1 with a 0.5 threshold
y_preds = [1 if x > 0.5 else 0 for x in y_pred_probs]
# performance evaluation
print(confusion_matrix(y_test, y_preds))
print(classification_report(y_test, y_preds))
print(roc_auc_score(y_test, y_preds))
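One small note: roc_auc_score is usually more informative when computed from the predicted probabilities rather than the thresholded 0/1 labels. A minimal variant using the y_pred_probs array from above:

# AUC computed from probabilities instead of the hard labels
print(roc_auc_score(y_test, y_pred_probs))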
from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(xgb_model, ax=ax)
The Python wrapper reports each feature's importance as an F score (by default, the number of times the feature is used to split, not the f1 score), and plot_importance() visualizes it directly.
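If you want the raw numbers instead of a plot, the trained booster exposes them through get_score(); a quick sketch:

# importance as a dict {feature_name: score}; 'weight' counts how often a feature is used to split
# feature names appear as f0, f1, ... because the DMatrix was built from a plain NumPy array
importance = xgb_model.get_score(importance_type='weight')
print(sorted(importance.items(), key=lambda kv: kv[1], reverse=True)[:5])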
scikit-learn wrapper: code walkthrough
from xgboost import XGBClassifier
# use a name other than xgb so the xgboost module alias imported earlier is not shadowed
xgb_clf = XGBClassifier(n_estimators = 400, learning_rate = 0.1, max_depth = 3)
xgb_clf.fit(X_train, y_train)
y_preds = xgb_clf.predict(X_test)
# performance evaluation
print(confusion_matrix(y_test, y_preds))
print(classification_report(y_test, y_preds))
print(roc_auc_score(y_test, y_preds))
The scikit-learn wrapper is very simple: you train and predict with the same fit and predict methods used for any scikit-learn model. scikit-learn utilities such as GridSearchCV can be used as-is (a sketch follows below), and both classification and regression are supported through XGBClassifier and XGBRegressor.
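Because the scikit-learn wrapper plugs into the usual scikit-learn tooling, hyperparameter search works out of the box; a rough sketch with GridSearchCV (the parameter grid here is an arbitrary example):

from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier

param_grid = {'max_depth': [3, 5], 'learning_rate': [0.05, 0.1], 'n_estimators': [100, 200]}
grid = GridSearchCV(XGBClassifier(), param_grid, scoring='roc_auc', cv=3)
grid.fit(X_train, y_train)                  # reuses the training split from above
print(grid.best_params_, grid.best_score_)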
That said, XGBoost training can be slow, so LightGBM, which trains faster with comparable performance, is often used instead.
For more details, see the LightGBM post in the related posts below.
Other posts in the 'Python > Machine Learning' category
- Support Vector Machine (SVM): concepts and key parameters (2021.07.27)
- Gradient Descent and why tuning learning_rate matters (2021.07.25)
- K-Nearest Neighbors (K-NN) classifier, the simplest machine learning algorithm (2021.07.21)
- LightGBM: parameters explained with a code walkthrough (2021.07.11)
- AdaBoost and Gradient Boosting: basic concepts and code, and the difference between boosting and bagging (2021.06.22)