이번 시간에는 kNN(k-Nearest Neighbors)을 이용해서 붓꽃의 품종을 예측하는 코드를 소개해드리도록 하겠습니다. kNN은 한국어로 k-최근접이웃으로 번역되기도 합니다. kNN의 작동원리는 매우 간단합니다. 새로운 데이터가 있는데 그 데이터의 클래스를 모른다면, 훈련 데이터셋에서 그 새로운 데이터와 가장 비슷하게 생긴 놈을 찾아서 걔랑 같은 유형이라고 라벨링해버리는 것입니다. kNN에 대한 이론적인 내용은 이전 포스팅을 참고해주세요.^^
유유상종의 진리를 이용한 분류 모델, kNN(k-Nearest Neighbors)
필요 라이브러리 설치
오늘 예제 코드를 실행하기 위해서는 scikit-learn(사이킷런) 라이브러리와 NumPy 라이브러리가 필요합니다. 마찬가지로 각자 개발환경에 맞게 설치해주시면 됩니다.^^ 아마 대체적으로 다음과 같은 명령을 터미널에 입력하면 설치가 되지 않을까 싶네요.
pip install scikit-learn
pip install numpy
사이킷런은 안드레아스 뮐러라는 분이 주도해서 만든 머신러닝 라이브러리입니다. 머신러닝으로 데이터를 핸들링 할 때 많이 사용되는 라이브러리입니다.
kNN으로 붓꽃 품종 예측하기
붓꽃 품종에는 setosa, versicolor, virginica와 같은 세 개의 종이 있습니다. 붓꽃의 꽃잎의 폭과 길이, 꽃받침의 폭과 길이를 측정한 것을 가지고 어느 붓꽃 품종에 속하는지를 예측하는 예제입니다. 간단히 말해서 4개의 특성(feature)과 3개의 클래스(class)가 있는 문제입니다.
150개의 샘플을 포함하고 있는 붓꽃 데이터셋에서 랜덤하게 선택한 75%를 훈련셋으로 나머지 25%를 테스트셋으로 해서 예측 정확도를 평가해보도록 하겠습니다. 이것을 수행해주는 코드는 다음과 같습니다.
중요한 부분에는 한글로 주석을 달아놓았으니 참고하시기 바랍니다. kNN에서 중요한 매개변수인 k는 1로 설정했습니다. 테스트 데이터 포인트와 가장 가까이 있는 훈련 데이터 포인트의 클래스를 테스트 데이터 포인트의 클래스로 해주겠다는 의미입니다.
import numpy as np
from sklearn.datasets import load_iris
iris_dataset = load_iris() # 붓꽃 데이터셋을 적재합니다.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
# 데이터셋을 랜덤하게 75%의 훈련셋과 25%의 테스트셋으로 분리합니다.
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train, y_train) # kNN 분류기를 훈련셋으로 훈련시킵니다.
y_pred = knn.predict(X_test) # 테스트셋의 라벨값을 예측합니다.
print("prediction accuracy: {:.2f}".format(np.mean(y_pred == y_test)))
이 코드를 실행하면, 다음과 같은 결과가 출력됩니다.
예측 정확도가 0.97이라고 하니 꽤 높죠? 150개 중에 25%가 테스트셋이니까, 38개의 샘플들의 클래스 중 97%의 클래스를 맞췄다는 것입니다. 거의 한, 두 개 빼고는 다 제대로 예측해낸 것입니다.
오늘은 kNN을 이용해서 붓꽃의 품종을 예측하는 문제에 대해 다뤄봤습니다. 사실 이 문제는 머신러닝 분야에서 아주 유명하고도 오래된 예제입니다. 항상 질문과 지적은 환영합니다.^^ 담에 또 찾아뵙겠습니다!
'Dev > python' 카테고리의 다른 글
[ubuntu+python] 선형 회귀의 업그레이드 버전2, 라쏘 회귀 (2) | 2020.01.20 |
---|---|
[ubuntu+python] 선형 회귀의 업그레이드 버전1, 릿지 회귀 (0) | 2020.01.20 |
[ubuntu+python] pip install과 apt-get install의 차이는? pip와 pip3는 뭐가 다르지? sudo란? (2) | 2020.01.20 |
[ubuntu+python] 선형 회귀(linear regression) (0) | 2020.01.19 |
[ubuntu+python] 웹캠 영상 실시간 물체(객체) 검출 (38) | 2020.01.17 |
[ubuntu+python] YOLOv3으로 물체(객체) 검출하기 (33) | 2020.01.15 |
[ubuntu+python] 얼굴 검출 후 성별 인식 (4) | 2020.01.15 |
[ubuntu+python] 얼굴 검출 (2) | 2020.01.15 |