karfly
4/10/2017 - 2:05 PM

svm.py

def plot_svm(x, y, svm_estimator, ax, title=''):
    """Draw a fitted binary SVM's class regions, boundary, margins and data on *ax*.

    Parameters
    ----------
    x : array, shape (n_samples, >= 2)
        Input points; only the first two columns are used for plotting and
        for building the evaluation grid.
    y : array, shape (n_samples,)
        Values used to colour the scatter points (``c=y``).
    svm_estimator : fitted estimator
        Must provide ``predict``, ``decision_function`` and
        ``support_vectors_`` (for example a fitted scikit-learn ``SVC``).
        ``decision_function`` must return one score per sample (binary
        problem); its output is reshaped onto the 2-D grid.
    ax : matplotlib Axes
        Axes that receive all drawing calls.
    title : str, optional
        Title for the axes.
    """
    x_min, x_max = np.min(x[:, 0]), np.max(x[:, 0])
    y_min, y_max = np.min(x[:, 1]), np.max(x[:, 1])

    # 400 x 400 evaluation grid spanning the data's bounding box; the flattened
    # point list is built once and reused for every estimator call.
    XX, YY = np.mgrid[x_min:x_max:400j, y_min:y_max:400j]
    grid = np.c_[XX.ravel(), YY.ravel()]

    # Predicted-class regions, plus a solid line where the prediction changes.
    Z = svm_estimator.predict(grid).reshape(XX.shape)
    ax.pcolormesh(XX, YY, Z, cmap='Paired')
    if np.unique(Z).size > 1:
        # A single predicted class has no boundary to draw.
        ax.contour(XX, YY, Z, 1, colors='black')

    # Margins: contour the continuous decision function at -1 and +1
    # (evaluated once). This is more accurate than contouring a thresholded
    # 0/1 mask. Only levels strictly inside the observed range are passed, so
    # a margin outside the window is skipped instead of being faked by
    # patching a grid cell.
    D = svm_estimator.decision_function(grid).reshape(XX.shape)
    d_min, d_max = D.min(), D.max()
    margin_levels = [lvl for lvl in (-1.0, 1.0) if d_min < lvl < d_max]
    if margin_levels:
        ax.contour(XX, YY, D, levels=margin_levels, colors='black', linestyles='dashed')

    # Support vectors as hollow black rings, drawn beneath the data points.
    sv = svm_estimator.support_vectors_
    ax.scatter(sv[:, 0], sv[:, 1], s=100, facecolors='none', edgecolors='black', zorder=2)
    ax.scatter(x[:, 0], x[:, 1], c=y, s=50, cmap='magma', zorder=3)

    ax.set_title(title)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    ax.set_xticks([])
    ax.set_yticks([])