본문 바로가기
CS/머신러닝

혼자 공부하는 머신러닝 + 딥러닝 - Ch 5-1

by jaehoonChoi 2023. 1. 13.

Ch 5-1에서는 결정 트리에 대해 공부합니다.

 

[ 문제 ]

와인의 데이터가 주어지고 target 데이터로 0(레드)과 1(화이트)이 주어집니다.

이들을 분류하는 기준을 알기 쉽게 만드는게 목표입니다. 

 

[ 접근 ] 

기존의 머신러닝 모델은 복잡한 식으로 구성되어 있어 어떤 기준으로 나뉘는지 파악하기 어렵습니다.

결정트리는 이런 경우 해결책을 줍니다. 

 

 

[ 데이터 분석 ]

와인 데이터를 불러옵시다.

import numpy as np
import pandas as pd
wine=pd.read_csv("https://bit.ly/wine_csv_data")
wine.head()

 

	alcohol	sugar	pH	class
0	9.4	1.9	3.51	0.0
1	9.8	2.6	3.20	0.0
2	9.8	2.3	3.26	0.0
3	9.8	1.9	3.16	0.0
4	9.4	1.9	3.51	0.0

 

이제 데이터를 전처리하고 분리하는 작업을 진행합니다.

from pandas.core.common import standardize_mapping
data=wine[['alcohol', 'sugar', 'pH']].to_numpy()
target=wine['class'].to_numpy()
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test=train_test_split(
    data, target, test_size=0.2, random_state=42
)
print(X_train.shape, X_test.shape)

from sklearn.preprocessing import StandardScaler
ss=StandardScaler()
ss.fit(X_train)
X_train=ss.transform(X_train)
X_test=ss.transform(X_test)

 

[ Decision Tree ] 

결정 트리를 이용해봅시다.

분류용 결정트리는 사이킷런 tree에서 DecisionTreeClassifier를 이용합니다.

from sklearn.tree import DecisionTreeClassifier
dt=DecisionTreeClassifier()
dt.fit(X_train, y_train)
print(dt.score(X_train, y_train))
print(dt.score(X_test, y_test))

 

0.996921300750433
0.8592307692307692

훈련성능은 매우 뛰어난 반면 테스트 성능은 좀 부족하네요. 과대적합된 모델임을 알 수 있습니다.

이렇게 만들어진 트리 그림을 출력해보면.. 

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
plt.figure(figsize=(10, 7))
plot_tree(dt)
plt.show()

분류기준을 알려고 결정트리를 사용한건데 이건 뭐 어떻게 나뉘는지 알 도리가 없네요. 

깊이를 설정해놓지 않아서 그냥 최대한 훈련 성능을 높이는 쪽으로만 깊게 파고들어서 그렇습니다.

최대깊이를 3으로 하고 다시 돌려봅시다. 

dt=DecisionTreeClassifier(max_depth=3)
dt.fit(X_train, y_train)
print(dt.score(X_train, y_train))
print(dt.score(X_test, y_test))

 

0.8454877814123533
0.8415384615384616

테스트 성능은 거의 그대로이지만 훈련 성능이 아까보단 낮아졌습니다. 과적합 문제는 해결한 것 같습니다.

이제 트리 그림을 출력해봅시다.

plt.figure(figsize=(20, 10))
plot_tree(dt, filled=True, feature_names=['alcohol', 'sugar', 'pH'])
plt.show()

 

와! 원하던 결과가 나왔습니다. 루트노드 쪽에선 당도를 기준으로, 더 내려갈수록

알코올 농도나 pH에 따라 분류해주고 있음을 알 수 있습니다. 이렇게 루트노드 쪽에서 먼저 기준을

나눠주는 특성이 데이터를 분류할 때 중요하게 고려되는 속성이라고 할 수 있습니다.

결정트리는 특성 중요도도 계산해주는데, 한번 출력해봅시다.

print(dt.feature_importances_)

 

[0.12345626 0.86862934 0.0079144 ]

예상한 바와 같이 2번째 특성 당도가 가장 중요한 기준임을 알 수 있습니다.

이렇게 결정트리는 분류기준을 이해하기 쉽게 만들어주는 모델 중 하나입니다.

다만 성능이 살짝 아쉬웠는데 이후 챕터에서 성능개선의 방법을 배웁니다.

 

 

댓글