1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | # -*- coding: utf-8 -*- """Day 14_MulticlassSVM.ipynb Automatically generated by Colaboratory. Original file is located at """ from google.colab import drive drive.mount('/gdrive') PATH = "/gdrive/My Drive/Colab Notebooks/resources/" # %matplotlib inline import numpy as np import matplotlib.pyplot as plt import matplotlib as mpl import pandas as pd import seaborn as sns import time from sklearn.svm import SVC, LinearSVC def run_multiclass_svm(datafile,C_value=1.0): data = np.loadtxt(datafile) n,d = data.shape ## 트레이닝 데이터와 타겟 데이터 분리 x = data[:,0:2] y = data[:,2] k = int(max(y)) + 1 print ("Number of classes: ", k) # 트레이닝 clf = LinearSVC(loss='hinge', multi_class='crammer_singer', C=C_value) clf.fit(x,y) # 결정경계 시각화 pred_fn = lambda p: clf.predict(p.reshape(1,-1)) display_data_and_boundary(x,y,pred_fn) def display_data_and_boundary(x,y,pred_fn): # Determine the x1- and x2- limits of the plot x1min = min(x[:,0]) - 1 x1max = max(x[:,0]) + 1 x2min = min(x[:,1]) - 1 x2max = max(x[:,1]) + 1 plt.xlim(x1min,x1max) plt.ylim(x2min,x2max) # Plot the data points k = int(max(y)) + 1 cols = ['ro', 'k^', 'b*','gx'] for label in range(k): plt.plot(x[(y==label),0], x[(y==label),1], cols[label%4], markersize=8) # Construct a grid of points at which to evaluate the classifier grid_spacing = 0.05 ## 가로 0~10 세로 0~10까지 0.05를 기준으로 모든 x 좌표값 , y 좌표값 생성 ## xx1 은 200 x 200 매트릭스에서 x1 변수의 자리수. 즉 x좌표, ## xx2 는 200 x 200 매트릭스에서 x2 변수의 자리수, 즉 y좌표 만을 가지고 있다 xx1, xx2 = np.meshgrid(np.arange(x1min, x1max, grid_spacing), np.arange(x2min, x2max, grid_spacing)) ## 그리드 내의 0.05 간격 모든 좌표값 200 x 200 요소를 가진 매트릭스의 각 좌표값을 가지고 있는 리스트가 된다. grid = np.c_[xx1.ravel(), xx2.ravel()] ## 각 좌표에서 내가 구분한 값의 클래스를 담는 변수 Z Z = np.array([pred_fn(pt) for pt in grid]) # Show the classifier's boundary using a color plot Z = Z.reshape(xx1.shape) plt.pcolormesh(xx1, xx2, Z, cmap=plt.cm.Pastel1, vmin=0, vmax=k) plt.show() run_multiclass_svm(PATH+'multiclassSVM/data_3.txt',10.0) | cs |
출처 및 참고자료 : edx - Machine Learning Fundamentals_week_6 Programming Assignment.3
'Python Library > Machine Learning' 카테고리의 다른 글
[NN] Day 01_introduction With MNIST (0) | 2019.07.23 |
---|---|
Day 10_PCA_MNIST (0) | 2019.07.21 |
Day 09_Multiclass_PerceptronClassifier (0) | 2019.07.19 |
Day 09_SVM_Sentiment_Analysis (0) | 2019.07.19 |
Day 08. Perceptron_Classification_Algorithm (0) | 2019.07.18 |