How to Evaluate Machine Learning Models using Classification Metrics: ROC, AUC, Precision, Recall, and Beyond
Get reusable code for each evaluation metric
When we create machine learning models, we need to ensure they perform well in real-world scenarios. In this post, we'll dive into essential evaluation metrics and techniques that every data scientist should understand.
1. ROC and AUC
ROC (Receiver Operating Characteristic) and AUC (Area Under the Curve) help us see how well our model separates two classes. The ROC curve is a graphical representation of a model's ability to distinguish between them.
The curve plots True Positive Rate (TPR) against False Positive Rate (FPR) at various threshold settings.
True Positive Rate (TPR)
\(\text{True Positive Rate (TPR)} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}}\)
False Positive Rate (FPR)
\(\text{False Positive Rate (FPR)} = \frac{\text{False Positives}}{\text{False Positives} + \text{True Negatives}}\)
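For example, with made-up numbers: if a model correctly flags 80 of 100 actual positives and falsely flags 10 of 100 actual negatives, then \(\text{TPR} = 80/100 = 0.8\) and \(\text{FPR} = 10/100 = 0.1\). The helper below plots the ROC curve for a set of true labels and predicted probabilities and reports the AUC.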
import numpy as np
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
def plot_roc_curve(y_true, y_pred_proba):
    # Compute FPR/TPR pairs across all decision thresholds
    fpr, tpr, thresholds = roc_curve(y_true, y_pred_proba)
    roc_auc = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2,
             label=f'ROC curve (AUC = {roc_auc:.2f})')
    # Diagonal reference line: the performance of a random classifier
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.show()
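As a quick way to try the helper, here is a minimal sketch using a synthetic dataset and a logistic regression model; the data, the model, and the variable names below are illustrative assumptions, not part of the original example.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hypothetical setup: synthetic binary classification data and a simple model
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# The helper expects probabilities for the positive class, not hard labels
plot_roc_curve(y_test, model.predict_proba(X_test)[:, 1])
Passing the output of predict_proba rather than predict matters here: the ROC curve is built by sweeping a threshold over scores, so it needs probabilities, not hard class labels.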
2. Sensitivity and Specificity
These two measures tell us:
Sensitivity - How well our model finds what we're looking for.
Specificity - How well it avoids false alarms.
Think of it like this:
Sensitivity: Out of all the actual positive cases, how many did we find?
Specificity: Out of all the actual negative cases, how many did we correctly identify?
Sensitivity (or Recall) and Specificity are critical metrics for evaluating classification models, especially in imbalanced datasets.
from sklearn.metrics import recall_score, confusion_matrix

def calculate_sensitivity_specificity(y_true, y_pred):
    # Sensitivity (same as recall): TP / (TP + FN)
    sensitivity = recall_score(y_true, y_pred)

    # Specificity: TN / (TN + FP), computed from the confusion matrix
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    specificity = tn / (tn + fp)
    return sensitivity, specificity
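Assuming the same hypothetical model, X_test, and y_test from the ROC sketch above, a minimal usage looks like this:
# Hard class predictions from the (hypothetical) fitted model
y_pred = model.predict(X_test)
sensitivity, specificity = calculate_sensitivity_specificity(y_test, y_pred)
print(f"Sensitivity: {sensitivity:.3f}, Specificity: {specificity:.3f}")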
3. Precision and Recall
Precision measures the accuracy of positive predictions, while Recall (Sensitivity) measures the completeness of positive predictions.
Precision and recall are especially informative when the classes in our data are imbalanced.
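In terms of confusion-matrix counts:
\(\text{Precision} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}}\)
\(\text{Recall} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}}\)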
from sklearn.metrics import precision_recall_curve
def plot_precision_recall_curve(y_true, y_pred_proba):
    # Compute precision/recall pairs across all decision thresholds
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)

    plt.figure()
    plt.plot(recall, precision, color='blue', lw=2)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.show()
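Reusing the hypothetical model and test split from the ROC sketch, the curve can be drawn from the same probability scores:
# Same positive-class probabilities as in the ROC example
plot_precision_recall_curve(y_test, model.predict_proba(X_test)[:, 1])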
4. Cross Validation
Cross-validation is a technique to ensure that the model generalizes well to unseen data by splitting the dataset into training and testing subsets multiple times.
In other words, cross-validation is like giving your model multiple tests with different question sets. It helps us ensure that our model works well with new data.
from sklearn.model_selection import cross_val_score, KFold

def perform_cross_validation(model, X, y, k=5):
    # Set up the K-Fold splitter (shuffled, fixed seed for reproducibility)
    kf = KFold(n_splits=k, shuffle=True, random_state=42)

    # Fit and score the model on each of the k train/test splits
    scores = cross_val_score(model, X, y, cv=kf, scoring='accuracy')

    print(f"Test scores: {scores}")
print(f"Average score: {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")5. Confusion Matrix
5. Confusion Matrix
The confusion matrix summarizes the performance of a classification algorithm by showing the counts of true positives, true negatives, false positives, and false negatives.
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(y_true, y_pred, labels=['Negative', 'Positive']):
    cm = confusion_matrix(y_true, y_pred)

    # Visualize the matrix as an annotated heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    # Calculate metrics from the confusion matrix counts
    tn, fp, fn, tp = cm.ravel()
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    specificity = tn / (tn + fp)

    print(f"Precision: {precision:.3f}")
    print(f"Recall: {recall:.3f}")
    print(f"Specificity: {specificity:.3f}")

