Rthur2003 commited on
Commit
058eadc
·
1 Parent(s): 27cd744

feat: implement optimal threshold calculation for binary predictions in train function

Browse files
Files changed (1) hide show
  1. app/training/train_classifier.py +3 -1
app/training/train_classifier.py CHANGED
@@ -170,7 +170,8 @@ def train(
170
  cv=cv,
171
  method="predict_proba",
172
  )[:, 1]
173
- y_pred = (y_prob >= 0.5).astype(int)
 
174
  cv_time = time.time() - t0
175
 
176
  acc = accuracy_score(y, y_pred)
@@ -194,6 +195,7 @@ def train(
194
  "recall": round(rec, 4),
195
  "f1": round(f1, 4),
196
  "roc_auc": round(auc, 4),
 
197
  "validation_auc": round(tuning_meta.get("validation_auc", 0.0), 4),
198
  "selection_time_sec": round(tuning_meta.get("selection_time_sec", 0.0), 2),
199
  "train_time_sec": round(cv_time, 2),
 
170
  cv=cv,
171
  method="predict_proba",
172
  )[:, 1]
173
+ threshold = _optimal_threshold(y, y_prob)
174
+ y_pred = (y_prob >= threshold).astype(int)
175
  cv_time = time.time() - t0
176
 
177
  acc = accuracy_score(y, y_pred)
 
195
  "recall": round(rec, 4),
196
  "f1": round(f1, 4),
197
  "roc_auc": round(auc, 4),
198
+ "optimal_threshold": round(threshold, 4),
199
  "validation_auc": round(tuning_meta.get("validation_auc", 0.0), 4),
200
  "selection_time_sec": round(tuning_meta.get("selection_time_sec", 0.0), 2),
201
  "train_time_sec": round(cv_time, 2),