Lý thuyết
3-4 gio
Bài 11/15

Threshold Analysis và Precision-Recall Curve

Phân tích Threshold, PR Curve va cách chọn threshold tối ưu

Threshold Analysis và Precision-Recall Curve

Mục tiêu bài học

Sau bài học này, học viên sẽ:

  • Hiểu ảnh hưởng của threshold đến Precision va Recall
  • Biet cách chọn threshold tối ưu theo từng bài toán
  • Nam vung Precision-Recall Curve va khi nao dung
  • Thực hành với Scikit-learn

1. Threshold trong Classification

1.1 Định nghĩa

Threshold la ngưỡng quyết định để chuyển probability thành class.

y^={1if P(y=1)threshold0if P(y=1)<threshold\hat{y} = \begin{cases} 1 & \text{if } P(y=1) \geq \text{threshold} \\ 0 & \text{if } P(y=1) < \text{threshold} \end{cases}

Default: threshold = 0.5

1.2 Trade-off Precision vs Recall

Thay đổiPrecisionRecallFPFN
Tăng ThresholdTăngGiảmGiảmTăng
Giảm ThresholdGiảmTăngTăngGiảm

2. Ví dụ tính toán thủ công

2.1 Dữ liệu

SampleTrue LabelP(Positive)
110.95
210.80
300.75
410.60
500.55
610.40
700.30
800.20

Tổng: 4 Positive, 4 Negative

2.2 Threshold = 0.7

Predict 0Predict 1
True 0TN=3FP=1
True 1FN=2TP=2

Precision=22+1=0.667Precision = \frac{2}{2+1} = 0.667 Recall=22+2=0.500Recall = \frac{2}{2+2} = 0.500

2.3 Threshold = 0.5

Predict 0Predict 1
True 0TN=2FP=2
True 1FN=1TP=3

Precision=33+2=0.600Precision = \frac{3}{3+2} = 0.600 Recall=33+1=0.750Recall = \frac{3}{3+1} = 0.750

2.4 Threshold = 0.35

Predict 0Predict 1
True 0TN=1FP=3
True 1FN=0TP=4

Precision=44+3=0.571Precision = \frac{4}{4+3} = 0.571 Recall=44+0=1.000Recall = \frac{4}{4+0} = 1.000

2.5 Tổng hop

ThresholdPrecisionRecallF1
0.700.6670.5000.571
0.500.6000.7500.667
0.350.5711.0000.727

3. Precision-Recall Curve

3.1 Khi nào dùng PR Curve thay vì ROC?

Tính huongĐúng
Data balancedROC Curve
Data highly imbalancedPR Curve
Positive class quan trọngPR Curve
Overall rankingROC Curve

3.2 Tại sao PR Curve tot hon cho Imbalanced Data?

Vi du: 990 Negative, 10 Positive

  • FPR = FP / (FP + TN) = 10 / 990 = 0.01 (rất nhỏ dù FP = 10)
  • Precision = TP / (TP + FP) = 5 / 15 = 0.33 (thể hiện rõ vấn đề)

Precision-Recall Curve

Hinh: Precision-Recall Curve tu Scikit-learn


4. Chọn Threshold tối ưu

4.1 Phương pháp 1: Tối đa F1-Score

Python
1from sklearn.metrics import precision_recall_curve
2import numpy as np
3
4precision, recall, thresholds = precision_recall_curve(y_test, y_prob)
5
6# Tính F1 cho mỗi threshold
7f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
8
9# Tìm threshold tối ưu
10optìmal_idx = np.argmax(f1_scores)
11optìmal_threshold = thresholds[optìmal_idx]
12optìmal_f1 = f1_scores[optìmal_idx]
13
14print(f"Optìmal Threshold: {optìmal_threshold:.4f}")
15print(f"F1 at optìmal: {optìmal_f1:.4f}")

4.2 Phương pháp 2: Youđến's J Statistic (cho ROC)

J=TPRFPR=Sensitivity+Specificity1J = TPR - FPR = Sensitivity + Specificity - 1

Python
1from sklearn.metrics import roc_curve
2
3fpr, tpr, thresholds = roc_curve(y_test, y_prob)
4
5# Youđến's J
6j_scores = tpr - fpr
7optìmal_idx = np.argmax(j_scores)
8optìmal_threshold = thresholds[optìmal_idx]
9
10print(f"Optìmal Threshold (Youđến's J): {optìmal_threshold:.4f}")

4.3 Phương pháp 3: Theo Business Requirement

Vi du y te - ưu tiên Recall >= 0.95:

Python
1# Tìm threshold thấp nhất để Recall >= 0.95
2for i, r in enumerate(recall):
3 if r >= 0.95:
4 threshold = thresholds[i]
5 prec = precision[i]
6 print(f"Threshold: {threshold:.4f}")
7 print(f"Recall: {r:.4f}, Precision: {prec:.4f}")
8 break

5. Ảnh hưởng của Threshold

5.1 Bảng tóm tắt

ThresholdPrecisionRecallKhi nào dùng
Cao (0.8)CaoThấpSpam filter, giảm FP
Trung bình (0.5)Cân bằngCân bằngGeneral purpose
Thấp (0.3)ThấpCaoY te, fraud, giảm FN

5.2 Tóm tắt Trade-off

Threshold cao (0.7-0.9):

  • Ít predict Positive
  • Precision cao, Recall thấp
  • Ít FP, nhiều FN
  • Đúng khi: FP cost cao (spam filter)

Threshold thấp (0.2-0.4):

  • Nhiều predict Positive
  • Precision thấp, Recall cao
  • Nhiều FP, ít FN
  • Đúng khi: FN cost cao (y te, fraud)

6. Thực hành với Scikit-learn

6.1 Code hoàn chỉnh

Python
1import numpy as np
2from sklearn.datasets import make_classification
3from sklearn.model_selection import train_test_split
4from sklearn.linear_model import LogisticRegression
5from sklearn.metrics import (precision_recall_curve, average_precision_score,
6 precision_score, recall_score, f1_score)
7import matplotlib.pyplot as plt
8
9# Tao imbalanced data
10X, y = make_classification(n_samples=1000, n_features=20,
11 weights=[0.9, 0.1], random_state=42)
12X_train, X_test, y_train, y_test = train_test_split(
13 X, y, test_size=0.3, random_state=42
14)
15
16# Train model
17model = LogisticRegression()
18model.fit(X_train, y_train)
19y_prob = model.predict_proba(X_test)[:, 1]
20
21# PR Curve
22precision, recall, thresholds = precision_recall_curve(y_test, y_prob)
23ap = average_precision_score(y_test, y_prob)
24
25# Ve PR Curve
26plt.figure(figsize=(10, 6))
27plt.plot(recall, precision, color='blue', lw=2,
28 label=f'PR curve (AP = {ap:.3f})')
29plt.axhline(y=sum(y_test)/len(y_test), color='gray',
30 linestyle='--', label='No skill')
31plt.xlabel('Recall')
32plt.ylabel('Precision')
33plt.title('Precision-Recall Curve')
34plt.legend()
35plt.grid(True)
36plt.show()
37
38# Tim optìmal threshold
39f1_scores = 2 * (precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-10)
40optìmal_idx = np.argmax(f1_scores)
41optìmal_threshold = thresholds[optìmal_idx]
42
43print(f"Optìmal Threshold: {optìmal_threshold:.4f}")
44print(f"Precision: {precision[optìmal_idx]:.4f}")
45print(f"Recall: {recall[optìmal_idx]:.4f}")
46print(f"F1: {f1_scores[optìmal_idx]:.4f}")

6.2 So sánh với default threshold

Python
1# Default threshold = 0.5
2y_pred_default = (y_prob >= 0.5).astype(int)
3print("\n=== Default Threshold (0.5) ===")
4print(f"Precision: {precision_score(y_test, y_pred_default):.4f}")
5print(f"Recall: {recall_score(y_test, y_pred_default):.4f}")
6print(f"F1: {f1_score(y_test, y_pred_default):.4f}")
7
8# Optìmal threshold
9y_pred_optìmal = (y_prob >= optìmal_threshold).astype(int)
10print(f"\n=== Optìmal Threshold ({optìmal_threshold:.4f}) ===")
11print(f"Precision: {precision_score(y_test, y_pred_optìmal):.4f}")
12print(f"Recall: {recall_score(y_test, y_pred_optìmal):.4f}")
13print(f"F1: {f1_score(y_test, y_pred_optìmal):.4f}")

7. Uu nhuoc điểm

PR Curve vs ROC Curve

AspectPR CurveROC Curve
Imbalanced dataTotCo the misleading
InterpretabilityTruc quan honPho bien hon
FocusPositive classCa 2 classes
BaselineTy le positiveDuong cheo

Bài tập tự luyện

  1. Bai tap 1: Tính Precision, Recall, F1 cho threshold = 0.3, 0.5, 0.7
  2. Bai tap 2: Ve PR Curve cho dataset co ty le 95% negative
  3. Bai tap 3: Tìm threshold de dat Recall >= 0.9 voi Precision cao nhat

Tài liệu tham khảo