# aurora1625
# 3/20/2019 - 8:30 AM
#
# Precision-recall curve

from sklearn.metrics import precision_recall_curve as prc
def _operating_point_index(pr, rc, ts, precision, recall, threshold):
    """Return the curve index chosen by the first non-None selector, or None.

    Selectors are checked in priority order: ``precision``, then ``recall``,
    then ``threshold``.  Raises ValueError when the chosen selector matches
    no point on the curve.
    """
    # Per the scikit-learn docs, ``ts`` is increasing and one element shorter
    # than ``pr``/``rc``: the final (precision=1, recall=0) point has no
    # threshold, so only the first ``len(ts)`` points are candidates.
    n = len(ts)
    if precision is not None:
        # Lowest threshold (first index) whose precision beats the target.
        i = next((k for k in range(n) if pr[k] > precision), None)
        criterion = 'precision > %r' % (precision,)
    elif recall is not None:
        # Highest threshold (last index) whose recall still beats the target.
        i = next((k for k in reversed(range(n)) if rc[k] > recall), None)
        criterion = 'recall > %r' % (recall,)
    elif threshold is not None:
        # precision[k]/recall[k] describe predictions with score >= ts[k], so
        # thresholding at ``threshold`` yields the point at the smallest
        # curve threshold that is not below it.
        i = next((k for k in range(n) if ts[k] >= threshold), None)
        criterion = 'threshold >= %r' % (threshold,)
    else:
        return None
    if i is None:
        raise ValueError(
            'no point on the precision-recall curve satisfies ' + criterion
        )
    return i


def plot_prc(
    clf, Xtest, ytest,
    precision=None, recall=None, threshold=None
):
    """Plot a classifier's precision-recall curve, optionally marking one point.

    The curve is computed with ``precision_recall_curve`` (imported as
    ``prc``) from column 1 of ``clf.predict_proba(Xtest)``.  Precision is drawn
    on the x axis and recall on the y axis, and the figure is displayed with
    ``plt.show()``.

    At most one operating point is highlighted and annotated with its
    ``(precision, recall)`` values.  If several selectors are given,
    ``precision`` takes priority over ``recall``, which takes priority over
    ``threshold``.

    Parameters
    ----------
    clf : estimator with ``predict_proba``
        Presumably a fitted binary classifier whose probability column 1 is
        the positive class -- confirm against the caller's label encoding.
    Xtest, ytest
        Evaluation features and true labels.
    precision : float, optional
        Mark the lowest threshold whose precision is strictly greater than
        this value.
    recall : float, optional
        Mark the highest threshold whose recall is strictly greater than
        this value.
    threshold : float, optional
        Mark the point obtained by predicting positive for scores
        ``>= threshold``.

    Raises
    ------
    ValueError
        If the chosen selector matches no point on the curve.

    Notes
    -----
    ``plt`` is not imported anywhere in this snippet; the enclosing module
    must provide it (e.g. ``import matplotlib.pyplot as plt``).
    """
    pr, rc, ts = prc(ytest, clf.predict_proba(Xtest)[:, 1])
    i = _operating_point_index(pr, rc, ts, precision, recall, threshold)

    plt.plot(pr, rc, c='red', lw=3)
    if i is not None:
        plt.plot([pr[i]], [rc[i]], marker='o', color='black')
        # Right-aligned label; the trailing space pads it away from the marker.
        plt.text(
            pr[i], rc[i], '(%.2f, %.2f) ' % (pr[i], rc[i]),
            fontdict={'ha': 'right', 'va': 'center'}
        )
    plt.grid()
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.title('Precision & Recall')
    plt.xlabel('Precision')
    plt.ylabel('Recall')
    plt.show()