
Understanding XGBoost in Python - Parameter Summary and Code Practice with the Python Wrapper and the Scikit-Learn Wrapper

XGBoost is a library that provides an optimized gradient boosting implementation for Python. It is fast, scalable, and portable, and it performs well enough to be a staple of winning entries in machine learning competitions such as Kaggle.

It uses boosting: samples that the previous model predicted poorly are given more weight, so that the next model focuses on correcting those mistakes and the ensemble improves step by step.
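For intuition, here is a minimal gradient-boosting sketch (plain scikit-learn trees, not XGBoost itself): each new tree is fit to the residual errors of the ensemble built so far, and its prediction is added with a small step size. All names here are illustrative.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

def toy_gradient_boosting(X, y, n_rounds=100, eta=0.1):
    """Toy gradient boosting for squared error."""
    pred = np.full(len(y), y.mean())  # start from the mean prediction
    trees = []
    for _ in range(n_rounds):
        residual = y - pred                            # what the ensemble still gets wrong
        tree = DecisionTreeRegressor(max_depth=3).fit(X, residual)
        pred += eta * tree.predict(X)                  # small corrective step
        trees.append(tree)
    return trees, pred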

Related post: the difference between boosting and bagging, and gradient boosting

It is faster than GBM, supports early stopping, and helps prevent overfitting. It can be used for both classification and regression. Since it is not included in scikit-learn, it must be installed separately.

pip install xgboost

It can be installed with just the command above.


XGBoost can be used in two ways, through the Python wrapper or the scikit-learn wrapper. When using the Python wrapper, the train and test data must be converted into a dedicated DMatrix.

*DMatrix: XGBoost's own dataset format, built from NumPy input arrays
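For example, a DMatrix can be built directly from NumPy arrays (a minimal sketch; pandas DataFrames and scipy sparse matrices are also accepted, and feature_names is optional):

import numpy as np
import xgboost as xgb

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)
dmat = xgb.DMatrix(X, label=y, feature_names=['f0', 'f1', 'f2', 'f3'])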


Key Parameters

-Booster Params (model configuration)

| Python wrapper | Scikit-learn wrapper | Description |
|---|---|---|
| eta (0.3) | learning_rate (0.1) | How much weight each boosting step carries; smaller values reflect each subsequent step less (typically set between 0.01 and 0.2) |
| num_boost_round (10) | n_estimators (100) | Number of weak learners to generate |
| min_child_weight (1) | min_child_weight (1) | Minimum sum of instance weights in a child node; used to curb overfitting |
| gamma (0) | gamma (0) | Minimum loss reduction required to split a leaf node further (alias: min_split_loss); larger values reduce overfitting |
| max_depth (6) | max_depth (3) | Tree depth |
| subsample (1) | subsample (1) | Fraction of training rows to sample; typically 0.5 to 1 |
| colsample_bytree (1) | colsample_bytree (1) | Fraction of features sampled per tree; typically 0.5 to 1 |
| lambda (1) | reg_lambda (1) | L2 (ridge) regularization weight; larger values reduce overfitting |
| alpha (0) | reg_alpha (0) | L1 (lasso) regularization weight; larger values reduce overfitting |
| scale_pos_weight (1) | scale_pos_weight (1) | Used to balance an imbalanced dataset |
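To make the naming differences concrete, here is one hypothetical configuration written for both wrappers (the values are illustrative, not a recommendation):

# Python wrapper: parameters go into a dict passed to xgb.train
py_params = {'eta': 0.1, 'max_depth': 4, 'min_child_weight': 1, 'gamma': 0,
             'subsample': 0.8, 'colsample_bytree': 0.8, 'lambda': 1, 'alpha': 0}
# booster = xgb.train(py_params, dtrain, num_boost_round=100)

# Scikit-learn wrapper: the same settings as constructor arguments
from xgboost import XGBClassifier
sk_model = XGBClassifier(learning_rate=0.1, n_estimators=100, max_depth=4,
                         min_child_weight=1, gamma=0, subsample=0.8,
                         colsample_bytree=0.8, reg_lambda=1, reg_alpha=0)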


How to tune parameters to prevent overfitting

-Lower eta (and, conversely, raise num_boost_round/n_estimators)

-Lower max_depth

-Raise min_child_weight

-Raise gamma

-Lower subsample and colsample_bytree (a combined sketch follows below)
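Put together, an overfitting-oriented adjustment might look roughly like this (a sketch only; the right values depend on the data):

conservative_params = {'eta': 0.05,             # lower learning rate
                       'max_depth': 3,          # shallower trees
                       'min_child_weight': 5,   # require more weight per child node
                       'gamma': 1,              # demand a larger loss reduction to split
                       'subsample': 0.8,        # sample rows
                       'colsample_bytree': 0.8} # sample features
# ...while compensating with more boosting rounds:
# xgb.train(conservative_params, dtrain, num_boost_round=1000, ...)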


-Learning Task Params (objective function and evaluation metric settings)

objective
-reg:linear : regression (default; renamed to reg:squarederror in newer XGBoost versions)
-binary:logistic : binary classification
-multi:softmax : multiclass classification, returns the predicted class
-multi:softprob : multiclass classification, returns class probabilities

eval_metric
-The default depends on the objective (regression: rmse, classification: error)
-rmse
-mae
-logloss
-error
-merror
-mlogloss
-auc
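For example, a multiclass setup would look roughly like this (note that the multi:* objectives additionally require num_class):

multi_params = {'objective': 'multi:softprob',  # return class probabilities
                'num_class': 3,                 # number of target classes
                'eval_metric': 'mlogloss'}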


Early stopping is performed with early_stopping_rounds; when using it, eval_set and eval_metric must always be set together.

*eval_set : the evaluation dataset used to measure performance

*eval_metric : the metric applied to the evaluation set

(at each boosting round, the prediction error on eval_set is measured with the eval_metric)
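The scikit-learn wrapper supports the same mechanism. As a hedged sketch (in older xgboost versions these arguments are passed to fit(); recent versions moved early_stopping_rounds and eval_metric into the constructor, so check your installed version):

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

X, y = load_breast_cancer(return_X_y=True)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, random_state=42)

model = XGBClassifier(n_estimators=400, learning_rate=0.1)
# API note: on recent xgboost releases, set early_stopping_rounds
# (and eval_metric) in the XGBClassifier constructor instead of fit()
model.fit(X_tr, y_tr, eval_set=[(X_val, y_val)],
          eval_metric='logloss', early_stopping_rounds=100)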


Python Wrapper Code Practice
import xgboost as xgb
from xgboost import plot_importance ## for plotting feature importance
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score

data = load_breast_cancer()
X_features = data.data
y_label = data.target

cancer = pd.DataFrame(data=X_features, columns = data.feature_names)
cancer['target'] = y_label
print(cancer.head())
print(cancer['target'].value_counts())
>>>1    357
   0    212
   Name: target, dtype: int64


Here 1 means benign and 0 means malignant, and the task is binary classification.


As mentioned earlier, when training with the Python wrapper, the train, validation, and test sets must be converted into DMatrix form.


#train,val,test split
X_train, X_test, y_train, y_test = train_test_split(X_features, y_label, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)

# Convert to DMatrix format
dtrain = xgb.DMatrix(data=X_train, label=y_train)
dval = xgb.DMatrix(data=X_val, label=y_val)
dtest = xgb.DMatrix(data=X_test)

# Build the model: the task is binary classification, so the objective is
# binary:logistic and the evaluation metric is logloss.
# Training runs for num_boost_round iterations but stops early if there is no
# improvement for early_stopping_rounds rounds.
# Using early_stopping_rounds requires eval_set and eval_metric to be set as well.

params = {'max_depth' : 3,
         'eta' : 0.1,
         'objective' : 'binary:logistic',
         'eval_metric' : 'logloss'}
         
# Label the training set 'train' and the evaluation set 'eval'
xgb_model = xgb.train(params = params, dtrain = dtrain, num_boost_round = 400, 
                        early_stopping_rounds = 100, evals=[(dtrain,'train'),(dval,'eval')])
                        
>>>
[0]	train-logloss:0.61110	eval-logloss:0.61392
[1]	train-logloss:0.54569	eval-logloss:0.54705
[2]	train-logloss:0.48914	eval-logloss:0.49040
[3]	train-logloss:0.43928	eval-logloss:0.44400
[4]	train-logloss:0.39793	eval-logloss:0.40186
[5]	train-logloss:0.36053	eval-logloss:0.36743
[6]	train-logloss:0.32755	eval-logloss:0.33466
[7]	train-logloss:0.30003	eval-logloss:0.30660
[8]	train-logloss:0.27555	eval-logloss:0.28448
[9]	train-logloss:0.25406	eval-logloss:0.26223
[10]	train-logloss:0.23388	eval-logloss:0.24419
[11]	train-logloss:0.21688	eval-logloss:0.22693
[12]	train-logloss:0.20016	eval-logloss:0.21070
[13]	train-logloss:0.18490	eval-logloss:0.19531
[14]	train-logloss:0.17197	eval-logloss:0.18232
[15]	train-logloss:0.15976	eval-logloss:0.17115
[16]	train-logloss:0.14969	eval-logloss:0.16147
[17]	train-logloss:0.14022	eval-logloss:0.15193
[18]	train-logloss:0.13169	eval-logloss:0.14399
[19]	train-logloss:0.12290	eval-logloss:0.13664
[20]	train-logloss:0.11589	eval-logloss:0.12951
[21]	train-logloss:0.10925	eval-logloss:0.12318
[22]	train-logloss:0.10244	eval-logloss:0.11784
[23]	train-logloss:0.09700	eval-logloss:0.11318
[24]	train-logloss:0.09153	eval-logloss:0.10907
[25]	train-logloss:0.08691	eval-logloss:0.10518
[26]	train-logloss:0.08220	eval-logloss:0.10217
[27]	train-logloss:0.07829	eval-logloss:0.09938
[28]	train-logloss:0.07394	eval-logloss:0.09597
[29]	train-logloss:0.07034	eval-logloss:0.09340
[30]	train-logloss:0.06734	eval-logloss:0.09138
[31]	train-logloss:0.06367	eval-logloss:0.08916
[32]	train-logloss:0.06075	eval-logloss:0.08714
[33]	train-logloss:0.05812	eval-logloss:0.08535
[34]	train-logloss:0.05554	eval-logloss:0.08305
[35]	train-logloss:0.05271	eval-logloss:0.08174
[36]	train-logloss:0.05067	eval-logloss:0.07960
[37]	train-logloss:0.04836	eval-logloss:0.07736
[38]	train-logloss:0.04611	eval-logloss:0.07662
[39]	train-logloss:0.04443	eval-logloss:0.07508
[40]	train-logloss:0.04253	eval-logloss:0.07395
[41]	train-logloss:0.04079	eval-logloss:0.07352
[42]	train-logloss:0.03943	eval-logloss:0.07205
[43]	train-logloss:0.03793	eval-logloss:0.07180
[44]	train-logloss:0.03653	eval-logloss:0.07093
[45]	train-logloss:0.03526	eval-logloss:0.07048
[46]	train-logloss:0.03406	eval-logloss:0.06987
[47]	train-logloss:0.03284	eval-logloss:0.06952
[48]	train-logloss:0.03189	eval-logloss:0.06944
[49]	train-logloss:0.03105	eval-logloss:0.06965
[50]	train-logloss:0.03018	eval-logloss:0.06890
[51]	train-logloss:0.02931	eval-logloss:0.06899
[52]	train-logloss:0.02829	eval-logloss:0.06883
[53]	train-logloss:0.02746	eval-logloss:0.06835
[54]	train-logloss:0.02672	eval-logloss:0.06752
[55]	train-logloss:0.02589	eval-logloss:0.06762
[56]	train-logloss:0.02524	eval-logloss:0.06753
[57]	train-logloss:0.02452	eval-logloss:0.06768
[58]	train-logloss:0.02398	eval-logloss:0.06784
[59]	train-logloss:0.02343	eval-logloss:0.06817
[60]	train-logloss:0.02276	eval-logloss:0.06838
[61]	train-logloss:0.02227	eval-logloss:0.06819
[62]	train-logloss:0.02184	eval-logloss:0.06834
[63]	train-logloss:0.02124	eval-logloss:0.06872
[64]	train-logloss:0.02076	eval-logloss:0.06858
[65]	train-logloss:0.02044	eval-logloss:0.06864
[66]	train-logloss:0.02011	eval-logloss:0.06891
[67]	train-logloss:0.01974	eval-logloss:0.06880
[68]	train-logloss:0.01946	eval-logloss:0.06890
[69]	train-logloss:0.01910	eval-logloss:0.06899
[70]	train-logloss:0.01873	eval-logloss:0.06865
[71]	train-logloss:0.01835	eval-logloss:0.06816
[72]	train-logloss:0.01799	eval-logloss:0.06778
[73]	train-logloss:0.01760	eval-logloss:0.06757
[74]	train-logloss:0.01720	eval-logloss:0.06760
[75]	train-logloss:0.01699	eval-logloss:0.06769
[76]	train-logloss:0.01677	eval-logloss:0.06784
[77]	train-logloss:0.01649	eval-logloss:0.06777
[78]	train-logloss:0.01622	eval-logloss:0.06759
[79]	train-logloss:0.01598	eval-logloss:0.06765
[80]	train-logloss:0.01572	eval-logloss:0.06825
[81]	train-logloss:0.01554	eval-logloss:0.06800
[82]	train-logloss:0.01534	eval-logloss:0.06776
[83]	train-logloss:0.01515	eval-logloss:0.06765
[84]	train-logloss:0.01494	eval-logloss:0.06681
[85]	train-logloss:0.01473	eval-logloss:0.06735
[86]	train-logloss:0.01459	eval-logloss:0.06745
[87]	train-logloss:0.01442	eval-logloss:0.06762
[88]	train-logloss:0.01426	eval-logloss:0.06737
[89]	train-logloss:0.01411	eval-logloss:0.06730
[90]	train-logloss:0.01394	eval-logloss:0.06786
[91]	train-logloss:0.01375	eval-logloss:0.06708
[92]	train-logloss:0.01363	eval-logloss:0.06718
[93]	train-logloss:0.01349	eval-logloss:0.06736
[94]	train-logloss:0.01332	eval-logloss:0.06663
[95]	train-logloss:0.01316	eval-logloss:0.06654
[96]	train-logloss:0.01298	eval-logloss:0.06710
[97]	train-logloss:0.01287	eval-logloss:0.06720
[98]	train-logloss:0.01275	eval-logloss:0.06738
[99]	train-logloss:0.01259	eval-logloss:0.06752
[100]	train-logloss:0.01249	eval-logloss:0.06762
[101]	train-logloss:0.01237	eval-logloss:0.06773
[102]	train-logloss:0.01224	eval-logloss:0.06765
[103]	train-logloss:0.01210	eval-logloss:0.06700
[104]	train-logloss:0.01198	eval-logloss:0.06685
[105]	train-logloss:0.01189	eval-logloss:0.06695
[106]	train-logloss:0.01176	eval-logloss:0.06724
[107]	train-logloss:0.01165	eval-logloss:0.06744
[108]	train-logloss:0.01157	eval-logloss:0.06753
[109]	train-logloss:0.01148	eval-logloss:0.06765
[110]	train-logloss:0.01138	eval-logloss:0.06753
[111]	train-logloss:0.01127	eval-logloss:0.06782
[112]	train-logloss:0.01119	eval-logloss:0.06792
[113]	train-logloss:0.01111	eval-logloss:0.06810
[114]	train-logloss:0.01104	eval-logloss:0.06794
[115]	train-logloss:0.01094	eval-logloss:0.06792
[116]	train-logloss:0.01082	eval-logloss:0.06778
[117]	train-logloss:0.01075	eval-logloss:0.06787
[118]	train-logloss:0.01071	eval-logloss:0.06814
[119]	train-logloss:0.01062	eval-logloss:0.06802
[120]	train-logloss:0.01056	eval-logloss:0.06814
[121]	train-logloss:0.01049	eval-logloss:0.06822
[122]	train-logloss:0.01043	eval-logloss:0.06833
[123]	train-logloss:0.01032	eval-logloss:0.06820
[124]	train-logloss:0.01026	eval-logloss:0.06828
[125]	train-logloss:0.01022	eval-logloss:0.06855
[126]	train-logloss:0.01014	eval-logloss:0.06845
[127]	train-logloss:0.01007	eval-logloss:0.06844
[128]	train-logloss:0.00997	eval-logloss:0.06832
[129]	train-logloss:0.00990	eval-logloss:0.06825
[130]	train-logloss:0.00987	eval-logloss:0.06812
[131]	train-logloss:0.00983	eval-logloss:0.06817
[132]	train-logloss:0.00980	eval-logloss:0.06803
[133]	train-logloss:0.00977	eval-logloss:0.06829
[134]	train-logloss:0.00974	eval-logloss:0.06800
[135]	train-logloss:0.00971	eval-logloss:0.06834
[136]	train-logloss:0.00968	eval-logloss:0.06856
[137]	train-logloss:0.00965	eval-logloss:0.06844
[138]	train-logloss:0.00962	eval-logloss:0.06816
[139]	train-logloss:0.00959	eval-logloss:0.06848
[140]	train-logloss:0.00956	eval-logloss:0.06873
[141]	train-logloss:0.00953	eval-logloss:0.06861
[142]	train-logloss:0.00950	eval-logloss:0.06859
[143]	train-logloss:0.00948	eval-logloss:0.06863
[144]	train-logloss:0.00945	eval-logloss:0.06887
[145]	train-logloss:0.00942	eval-logloss:0.06918
[146]	train-logloss:0.00939	eval-logloss:0.06891
[147]	train-logloss:0.00937	eval-logloss:0.06879
[148]	train-logloss:0.00934	eval-logloss:0.06877
[149]	train-logloss:0.00931	eval-logloss:0.06899
[150]	train-logloss:0.00929	eval-logloss:0.06930
[151]	train-logloss:0.00926	eval-logloss:0.06903
[152]	train-logloss:0.00924	eval-logloss:0.06926
[153]	train-logloss:0.00921	eval-logloss:0.06924
[154]	train-logloss:0.00919	eval-logloss:0.06955
[155]	train-logloss:0.00916	eval-logloss:0.06964
[156]	train-logloss:0.00914	eval-logloss:0.06938
[157]	train-logloss:0.00911	eval-logloss:0.06961
[158]	train-logloss:0.00909	eval-logloss:0.06991
[159]	train-logloss:0.00907	eval-logloss:0.06989
[160]	train-logloss:0.00904	eval-logloss:0.06989
[161]	train-logloss:0.00902	eval-logloss:0.06978
[162]	train-logloss:0.00900	eval-logloss:0.06974
[163]	train-logloss:0.00898	eval-logloss:0.06949
[164]	train-logloss:0.00895	eval-logloss:0.06978
[165]	train-logloss:0.00893	eval-logloss:0.07000
[166]	train-logloss:0.00891	eval-logloss:0.06998
[167]	train-logloss:0.00889	eval-logloss:0.06998
[168]	train-logloss:0.00887	eval-logloss:0.06996
[169]	train-logloss:0.00885	eval-logloss:0.07025
[170]	train-logloss:0.00883	eval-logloss:0.07001
[171]	train-logloss:0.00881	eval-logloss:0.07022
[172]	train-logloss:0.00878	eval-logloss:0.07019
[173]	train-logloss:0.00876	eval-logloss:0.07008
[174]	train-logloss:0.00874	eval-logloss:0.07036
[175]	train-logloss:0.00872	eval-logloss:0.07012
[176]	train-logloss:0.00870	eval-logloss:0.07012
[177]	train-logloss:0.00868	eval-logloss:0.07010
[178]	train-logloss:0.00866	eval-logloss:0.07031
[179]	train-logloss:0.00864	eval-logloss:0.06997
[180]	train-logloss:0.00862	eval-logloss:0.06996
[181]	train-logloss:0.00860	eval-logloss:0.06995
[182]	train-logloss:0.00859	eval-logloss:0.06985
[183]	train-logloss:0.00857	eval-logloss:0.07012
[184]	train-logloss:0.00855	eval-logloss:0.07002
[185]	train-logloss:0.00853	eval-logloss:0.06979
[186]	train-logloss:0.00851	eval-logloss:0.06975
[187]	train-logloss:0.00849	eval-logloss:0.06996
[188]	train-logloss:0.00847	eval-logloss:0.06995
[189]	train-logloss:0.00846	eval-logloss:0.06994
[190]	train-logloss:0.00844	eval-logloss:0.07019
[191]	train-logloss:0.00842	eval-logloss:0.07018
[192]	train-logloss:0.00840	eval-logloss:0.06985
[193]	train-logloss:0.00838	eval-logloss:0.07006
[194]	train-logloss:0.00837	eval-logloss:0.06983


Since this is a binary classification model, objective was set to binary:logistic and eval_metric was changed to logloss.

Early stopping was enabled with early_stopping_rounds; in that case eval_set and eval_metric must be configured together. In evals, dtrain was labeled 'train' and dval was labeled 'eval'.


As training iterates, both train-logloss and eval-logloss decrease. With training complete, all that remains is to run predictions with the predict method and evaluate performance.


# Predict; probabilities are returned
y_pred_probs = xgb_model.predict(dtest)

# Convert to 0 or 1
y_preds = [1 if x>0.5 else 0 for x in y_pred_probs]

# Evaluate performance
print(confusion_matrix(y_test, y_preds))
print(classification_report(y_test, y_preds))
print(roc_auc_score(y_test, y_preds))
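
Because early stopping fired, the trained booster also records where the validation loss was best; a quick check (attribute availability can vary slightly across xgboost versions):

print(xgb_model.best_iteration)  # boosting round with the lowest eval-logloss
print(xgb_model.best_score)      # the eval-logloss value at that round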


from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(xgb_model, ax=ax)

The Python wrapper reports each feature's importance as an F score (how many times the feature is used for splitting across the trees), and plot_importance() visualizes it directly.
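The underlying numbers can also be read directly from the booster; 'weight' counts how often each feature is used for splitting, and other importance_type values such as 'gain' are available as well:

importance = xgb_model.get_score(importance_type='weight')  # e.g. {'f0': 12, 'f22': 9, ...}
print(sorted(importance.items(), key=lambda kv: kv[1], reverse=True)[:5])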


Scikit-Learn Wrapper Code Practice
from xgboost import XGBClassifier

# Use a name that does not shadow the xgb module imported earlier
xgb_clf = XGBClassifier(n_estimators = 400, learning_rate = 0.1, max_depth = 3)
xgb_clf.fit(X_train, y_train)
y_preds = xgb_clf.predict(X_test)

# Evaluate performance
print(confusion_matrix(y_test, y_preds))
print(classification_report(y_test, y_preds))
print(roc_auc_score(y_test, y_preds))


The scikit-learn wrapper is very straightforward: train and predict with fit and predict, exactly as with any scikit-learn model. Scikit-learn utilities such as GridSearchCV can be used as-is, and both classification and regression are covered by XGBClassifier and XGBRegressor.
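For instance, hyperparameter search works out of the box; a hedged GridSearchCV sketch with illustrative grid values:

from sklearn.model_selection import GridSearchCV

param_grid = {'max_depth': [3, 5, 7],
              'learning_rate': [0.05, 0.1],
              'n_estimators': [100, 400]}
grid = GridSearchCV(XGBClassifier(), param_grid, scoring='roc_auc', cv=3)
grid.fit(X_train, y_train)
print(grid.best_params_, grid.best_score_)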


That said, XGBoost training can take a long time, so LightGBM, which trains faster while performing comparably well, is often used instead.

For more on that, see the post below.

Understanding LightGBM concepts