Spaces:
Sleeping
Sleeping
| """ | |
| threshold_optimizer.py - Post-training threshold calibration tool. | |
| Run this standalone to re-optimize the probability threshold on new data | |
| WITHOUT retraining the model. Useful for: | |
| - Adapting to regime changes without full retraining | |
| - Testing different optimization objectives | |
| - Out-of-sample threshold validation | |
| The threshold search maximizes expectancy or Sharpe over a held-out dataset. | |
| Usage: | |
| python threshold_optimizer.py --symbols BTC-USDT ETH-USDT --bars 200 | |
| python threshold_optimizer.py --objective sharpe | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use("Agg") # non-interactive backend | |
| import matplotlib.pyplot as plt | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from ml_config import ( | |
| THRESHOLD_PATH, | |
| THRESHOLD_MIN, | |
| THRESHOLD_MAX, | |
| THRESHOLD_STEPS, | |
| THRESHOLD_OBJECTIVE, | |
| TARGET_RR, | |
| ROUND_TRIP_COST, | |
| FEATURE_COLUMNS, | |
| ML_DIR, | |
| ) | |
| from ml_filter import TradeFilter | |
| from feature_builder import build_feature_dict, validate_features | |
| from train import build_dataset | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") | |
def compute_threshold_curve(
    probs: np.ndarray,
    y_true: np.ndarray,
    rr: float = TARGET_RR,
    cost: float = ROUND_TRIP_COST,
    min_samples: int = 5,
) -> pd.DataFrame:
    """
    Sweep the threshold grid and compute trade metrics at each threshold.

    The grid is ``np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)``.
    At each threshold the samples with ``probs >= threshold`` are treated as
    taken trades; a label equal to 1 counts as a win worth ``rr`` and any
    other label as a loss worth ``-1.0``, with ``cost`` subtracted from every
    trade.

    Args:
        probs: Predicted probabilities, one per sample.
        y_true: Labels aligned with ``probs`` (1 = win).
        rr: Payoff of a winning trade, in units of the loss size.
        cost: Cost subtracted from every trade's PnL.
        min_samples: Minimum number of selected trades required to compute
            metrics; thresholds selecting fewer get NaN metrics.

    Returns:
        DataFrame with one row per threshold and the columns ``threshold``,
        ``n_trades``, ``win_rate``, ``expectancy``, ``sharpe``, ``precision``
        and ``coverage``. Float values are rounded to 4 decimals.

    Raises:
        ValueError: If ``probs`` and ``y_true`` differ in length.
    """
    probs = np.asarray(probs, dtype=np.float64).ravel()
    y_true = np.asarray(y_true).ravel()
    if probs.shape != y_true.shape:
        raise ValueError(
            f"probs and y_true must have the same length, "
            f"got {probs.size} and {y_true.size}"
        )

    total = y_true.size
    thresholds = np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)
    records = []
    for t in thresholds:
        mask = probs >= t
        n = int(mask.sum())
        # Fields shared by both row kinds, so sparse rows carry the same
        # rounded threshold and a coverage that agrees with n_trades.
        row = {
            "threshold": round(float(t), 4),
            "n_trades": n,
            "coverage": round(n / total, 4) if total else 0.0,
        }

        if n < min_samples:
            # Too few trades for meaningful statistics.
            row.update(
                win_rate=np.nan,
                expectancy=np.nan,
                sharpe=np.nan,
                precision=np.nan,
            )
        else:
            selected = y_true[mask]
            win_rate = float(selected.mean())
            expectancy = win_rate * rr - (1 - win_rate) * 1.0 - cost
            pnl = np.where(selected == 1, rr, -1.0) - cost
            pnl_std = pnl.std()
            # Per-trade Sharpe scaled by sqrt(252).
            # NOTE(review): this annualization presumably assumes about one
            # trade per trading day - confirm.
            sharpe = (
                pnl.mean() / pnl_std * np.sqrt(252) if pnl_std > 1e-9 else 0.0
            )
            row.update(
                win_rate=round(win_rate, 4),
                expectancy=round(expectancy, 4),
                sharpe=round(sharpe, 4),
                # Precision of the "take trade" decision equals the win rate.
                precision=round(win_rate, 4),
            )
        records.append(row)

    columns = [
        "threshold", "n_trades", "win_rate", "expectancy",
        "sharpe", "precision", "coverage",
    ]
    return pd.DataFrame(records, columns=columns)
def find_optimal_threshold(
    curve: pd.DataFrame,
    objective: str = THRESHOLD_OBJECTIVE,
    min_trades: int = 20,
    default: float = 0.55,
) -> float:
    """
    Pick the threshold that maximizes ``objective`` on a threshold curve.

    Only rows with at least ``min_trades`` trades and a non-NaN objective
    value are considered. On ties the first such row in ``curve`` wins.

    Args:
        curve: DataFrame with at least the columns ``threshold``,
            ``n_trades`` and the objective column.
        objective: Name of the column to maximize.
        min_trades: Minimum ``n_trades`` a row needs to be eligible.
        default: Threshold returned (with a warning) when no row is eligible.

    Returns:
        The selected threshold, or ``default`` if nothing qualifies.

    Raises:
        KeyError: If ``objective`` is not a column of ``curve``.
    """
    if objective not in curve.columns:
        raise KeyError(
            f"Objective column {objective!r} not found in threshold curve"
        )

    enough_trades = curve[curve["n_trades"] >= min_trades]
    eligible = enough_trades.dropna(subset=[objective])
    if eligible.empty:
        logger.warning(
            "No valid threshold found - using default %s", default
        )
        return float(default)

    best_index = eligible[objective].idxmax()
    return float(eligible.loc[best_index, "threshold"])
def plot_threshold_curves(curve: pd.DataFrame, optimal: float, save_path: Path) -> None:
    """
    Save a 2x2 grid of metric-vs-threshold plots to ``save_path``.

    The panels show expectancy, Sharpe, win rate and trade count; rows where
    a metric is NaN are skipped for that panel. Each panel marks ``optimal``
    with a dashed vertical line.

    Args:
        curve: DataFrame with the columns ``threshold``, ``expectancy``,
            ``sharpe``, ``win_rate`` and ``n_trades``.
        optimal: Threshold to highlight in every panel.
        save_path: Destination of the image file.
    """
    panels = (
        ("expectancy", "Expectancy per Trade"),
        ("sharpe", "Annualized Sharpe"),
        ("win_rate", "Win Rate"),
        ("n_trades", "# Trades"),
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    try:
        fig.suptitle("Threshold Optimization", fontsize=14, fontweight="bold")
        for ax, (metric, title) in zip(axes.flatten(), panels):
            valid = curve.dropna(subset=[metric])
            ax.plot(valid["threshold"], valid[metric], lw=2, color="#1a6bff")
            ax.axvline(
                optimal,
                color="orange",
                linestyle="--",
                lw=1.5,
                label=f"Optimal={optimal:.3f}",
            )
            ax.axhline(0, color="gray", linestyle=":", lw=0.8)
            ax.set_title(title, fontsize=11)
            ax.set_xlabel("Threshold")
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_path, dpi=120, bbox_inches="tight")
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
    logger.info("Threshold curve plot saved -> %s", save_path)
def main(args):
    """
    Re-optimize the probability threshold and persist the result.

    Steps: load the trained filter, build the dataset for the requested
    symbols, score every row, sweep the threshold grid, pick the threshold
    maximizing ``args.objective``, then write the threshold JSON, the curve
    CSV and (best effort) the curve plot.

    The process exits with status 1, leaving any existing threshold file
    untouched, when no trained model is available, the dataset is empty, or
    no threshold has enough trades to be optimized.

    Args:
        args: Parsed CLI namespace providing ``symbols``, ``bars`` and
            ``objective``.
    """
    trade_filter = TradeFilter.load_or_none()
    if trade_filter is None:
        logger.error("No trained model found. Run train.py first.")
        sys.exit(1)

    symbols = args.symbols or ["BTC-USDT", "ETH-USDT", "SOL-USDT", "BNB-USDT"]
    dataset = build_dataset(symbols, bars=args.bars)
    if len(dataset) == 0:
        logger.error("Dataset is empty - nothing to optimize.")
        sys.exit(1)

    features = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    # One {feature: float} dict per row; tolist() yields plain Python floats.
    feature_dicts = [dict(zip(FEATURE_COLUMNS, row.tolist())) for row in features]
    probs = np.asarray(trade_filter.predict_batch(feature_dicts), dtype=np.float64)
    logger.info(f"Generated {len(probs)} predictions | mean_prob={probs.mean():.4f}")

    curve = compute_threshold_curve(probs, y)

    # Check up front that an optimizable threshold exists. Otherwise
    # find_optimal_threshold would return its fallback value, and the stats
    # written below would be NaN or belong to an unrelated grid row.
    min_trades = 20
    usable = (curve["n_trades"] >= min_trades) & curve[args.objective].notna()
    if not usable.any():
        logger.error(
            "No threshold with at least %d trades - keeping existing threshold.",
            min_trades,
        )
        sys.exit(1)

    optimal = find_optimal_threshold(
        curve, objective=args.objective, min_trades=min_trades
    )
    # Restrict the lookup to usable rows so the reported stats belong to a
    # row that was actually eligible for selection.
    best_row = curve[usable & np.isclose(curve["threshold"], optimal)].iloc[0]

    logger.info("\n=== THRESHOLD OPTIMIZATION RESULT ===")
    logger.info(f" Objective: {args.objective}")
    logger.info(f" Optimal threshold: {optimal:.4f}")
    logger.info(f" Win rate: {best_row['win_rate']:.4f}")
    logger.info(f" Expectancy: {best_row['expectancy']:.4f}")
    logger.info(f" Sharpe: {best_row['sharpe']:.4f}")
    logger.info(f" # Trades: {int(best_row['n_trades'])}")
    logger.info(f" Coverage: {best_row['coverage']:.2%}")

    # Update threshold file
    ML_DIR.mkdir(parents=True, exist_ok=True)
    Path(THRESHOLD_PATH).parent.mkdir(parents=True, exist_ok=True)
    thresh_data = {
        "threshold": optimal,
        "objective": args.objective,
        "win_rate_at_threshold": float(best_row["win_rate"]),
        "expectancy_at_threshold": float(best_row["expectancy"]),
        "sharpe_at_threshold": float(best_row["sharpe"]),
        "n_trades_at_threshold": int(best_row["n_trades"]),
    }
    with open(THRESHOLD_PATH, "w", encoding="utf-8") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info("Threshold updated -> %s", THRESHOLD_PATH)

    # Save curve CSV
    curve_path = ML_DIR / "threshold_curve.csv"
    curve.to_csv(curve_path, index=False)

    # Plot (best effort: a plotting failure must not fail the calibration)
    plot_path = ML_DIR / "threshold_curve.png"
    try:
        plot_threshold_curves(curve, optimal, plot_path)
    except Exception as e:
        logger.warning(f"Plot failed: {e}")
if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to main().
    cli = argparse.ArgumentParser(description="Optimize probability threshold")
    # Symbols to evaluate; when omitted, main() falls back to its own list.
    cli.add_argument("--symbols", default=None, nargs="+")
    # Passed to build_dataset() as its `bars` argument.
    cli.add_argument("--bars", default=200, type=int)
    # Curve column to maximize when selecting the threshold.
    cli.add_argument(
        "--objective",
        default=THRESHOLD_OBJECTIVE,
        choices=["expectancy", "sharpe", "win_rate"],
    )
    main(cli.parse_args())