from sklearn.metrics import precision_recall_curve as prc
def _operating_point_index(pr, rc, ts, precision=None, recall=None, threshold=None):
    """Return the index into ``pr``/``rc`` of the point to highlight, or None.

    ``pr``, ``rc`` and ``ts`` are the arrays returned by
    ``precision_recall_curve``: ``ts`` is increasing and one element shorter
    than ``pr``/``rc``, whose final entry (precision 1, recall 0) has no
    threshold. Only the first non-None target (precision, then recall, then
    threshold) is used; None is returned when no target is given.

    Raises:
        ValueError: if no threshold exceeds the requested precision or recall.
    """
    n = len(ts)
    if precision is not None:
        # Lowest threshold whose precision exceeds the target; thresholds are
        # increasing, so this is the first qualifying index.
        i = next((k for k, p in enumerate(pr[:n]) if p > precision), None)
        if i is None:
            raise ValueError(f"no threshold gives precision > {precision}")
        return i
    if recall is not None:
        # Highest threshold whose recall still exceeds the target, i.e. the
        # last qualifying index.
        i = max((k for k, r in enumerate(rc[:n]) if r > recall), default=None)
        if i is None:
            raise ValueError(f"no threshold gives recall > {recall}")
        return i
    if threshold is not None:
        # pr[k]/rc[k] describe predictions with score >= ts[k], so the rule
        # "score >= threshold" corresponds to the smallest ts[k] >= threshold.
        # Thresholds at or below ts[0] map to index 0. A threshold above every
        # score predicts nothing positive, which is the final
        # (precision 1, recall 0) point at index n.
        return next((k for k, t in enumerate(ts) if t >= threshold), n)
    return None


def plot_prc(
    clf, Xtest, ytest,
    precision=None, recall=None, threshold=None
):
    """Plot the precision-recall curve of ``clf`` on a test set and show it.

    Scores are ``clf.predict_proba(Xtest)[:, 1]``, the probability of the
    second class in ``clf.classes_``, which is assumed to be the positive class.
    The curve is drawn with precision on the x-axis and recall on the y-axis.

    Optionally, one operating point is marked and labelled with its
    (precision, recall) values. Only the first non-None target is honoured,
    in this order:

    - ``precision``: the lowest threshold whose precision exceeds this value.
    - ``recall``: the highest threshold whose recall exceeds this value.
    - ``threshold``: the point for the decision rule ``score >= threshold``.

    Raises:
        ValueError: if the requested precision or recall is never exceeded.

    NOTE(review): relies on a module-level ``plt`` (presumably
    ``matplotlib.pyplot``) that is not imported here; ends with ``plt.show()``.
    """
    scores = clf.predict_proba(Xtest)[:, 1]
    pr, rc, ts = prc(ytest, scores)
    i = _operating_point_index(pr, rc, ts, precision, recall, threshold)

    plt.plot(pr, rc, c='red', lw=3)
    if i is not None:
        plt.plot([pr[i]], [rc[i]], marker='o', color='black')
        plt.text(
            pr[i], rc[i], '(%.2f, %.2f) ' % (pr[i], rc[i]),
            fontdict={'ha': 'right', 'va': 'center'},
        )
    plt.grid()
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.title('Precision & Recall')
    plt.xlabel('Precision')
    plt.ylabel('Recall')
    plt.show()