Spaces:
Sleeping
Sleeping
| """ | |
| train.py β Full training pipeline. Run this script to train the model. | |
| Usage: | |
| python train.py --symbols BTC-USDT ETH-USDT SOL-USDT ... --bars 500 | |
| python train.py --use-defaults --bars 300 | |
| python train.py --data-dir ./historical_csv # load pre-saved CSVs | |
| Pipeline: | |
| 1. Fetch OHLCV for all symbols | |
| 2. Run rule engine to extract features (no lookahead) | |
| 3. Label each signal bar with forward-looking outcome | |
| 4. Concatenate all symbols (adds cross-asset diversity) | |
| 5. Walk-forward validation β choose threshold | |
| 6. Final model fit on full dataset | |
| 7. Save model + threshold + feature importances | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from config import DEFAULT_SYMBOLS, TIMEFRAME, CANDLE_LIMIT | |
| from data_fetcher import fetch_multiple | |
| from regime import detect_regime | |
| from volume_analysis import analyze_volume | |
| from scorer import compute_structure_score, score_token | |
| from veto import apply_veto | |
| from feature_builder import build_feature_dict, validate_features | |
| from labeler import label_dataframe, compute_label_stats | |
| from walk_forward import run_walk_forward, summarize_walk_forward | |
| from model_backend import ModelBackend | |
| from ml_config import ( | |
| ML_DIR, | |
| MODEL_PATH, | |
| THRESHOLD_PATH, | |
| FEATURE_IMP_PATH, | |
| LABEL_PATH, | |
| LGBM_PARAMS, | |
| FEATURE_COLUMNS, | |
| LABEL_FORWARD_BARS, | |
| THRESHOLD_MIN, | |
| THRESHOLD_MAX, | |
| THRESHOLD_STEPS, | |
| THRESHOLD_OBJECTIVE, | |
| STOP_MULT, | |
| ) | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", | |
| stream=sys.stdout, | |
| ) | |
| logger = logging.getLogger("train") | |
def infer_direction(trend: str, breakout: int) -> int:
    """Map rule-engine outputs to a trade direction.

    Returns +1 (long) when the trend is bullish or an upside breakout fired,
    -1 (short) when bearish or a downside breakout fired, else 0 (no trade).
    A bullish trend takes precedence over a downside breakout signal.
    """
    if trend == "bullish" or breakout == 1:
        return 1
    return -1 if (trend == "bearish" or breakout == -1) else 0
def extract_features_and_labels(
    symbol: str,
    df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Run the full rule engine over a DataFrame, bar by bar (forward-only).

    Returns a DataFrame with feature columns + 'label' + 'direction' + 'timestamp'.

    Implementation note: we compute regime/volume/scores using the full
    historical series up to each bar — no information from future bars
    is ever used. The label is computed separately using FORWARD bars only.

    Parameters
    ----------
    symbol : str
        Symbol name, carried into the output as the ``_symbol`` column.
    df : pd.DataFrame
        OHLCV frame indexed by timestamp; must have at least 60 bars.

    Returns
    -------
    pd.DataFrame
        One row per labeled signal bar (empty frame if the input is too
        short, the rule engine errors out, or no bar yields valid features).
    """
    if len(df) < 60:
        logger.warning(f"{symbol}: too short ({len(df)} bars), skipping")
        return pd.DataFrame()
    # Compute full-series regime and volume (these use only past data internally)
    try:
        regime_data = detect_regime(df)
        # NOTE(review): volume_data is not used below — this call appears to
        # serve only as a sanity check that the volume pipeline runs; confirm.
        volume_data = analyze_volume(df, atr_series=regime_data["atr_series"])
    except Exception as e:
        logger.error(f"{symbol}: rule engine error: {e}")
        return pd.DataFrame()
    # Full-series ATR is reused later for labeling (stop distance scaling).
    atr_series = regime_data["atr_series"]
    # Build per-bar feature rows for all bars with valid ATR (skip first ATR_PERIOD)
    rows = []
    n = len(df)
    for i in range(30, n):
        # Slice up to bar i (inclusive) — simulate running bar by bar
        df_i = df.iloc[: i + 1]
        try:
            r_i = detect_regime(df_i)
            v_i = analyze_volume(df_i, atr_series=r_i["atr_series"])
        except Exception:
            # Early slices may lack enough history for indicators; skip quietly.
            continue
        sc_i = compute_structure_score(r_i)
        direction = infer_direction(r_i["trend"], v_i["breakout"])
        vetoed, _ = apply_veto(r_i, v_i, sc_i, direction=direction)
        # Only label bars that the rule engine would have flagged as signals
        is_signal = not vetoed and r_i["regime_confidence"] > 0.3
        scores = score_token(r_i, v_i, vetoed=False)  # compute scores even if vetoed
        try:
            feat = build_feature_dict(r_i, v_i, scores)
        except (KeyError, ValueError):
            continue
        if not validate_features(feat):
            continue
        # Bookkeeping columns (underscore prefix = not model features).
        feat["_symbol"] = symbol
        feat["_bar_idx"] = i
        feat["_timestamp"] = df.index[i]
        feat["_is_signal"] = int(is_signal)
        feat["_direction"] = direction
        feat["_atr"] = float(r_i["atr"])
        rows.append(feat)
    if not rows:
        return pd.DataFrame()
    result = pd.DataFrame(rows)
    # Label: compute forward outcomes for signal bars
    signal_mask_full = pd.Series(False, index=df.index)
    direction_full = pd.Series(0, index=df.index)
    atr_full = atr_series
    for row in rows:
        if row["_is_signal"]:
            idx = df.index[row["_bar_idx"]]
            signal_mask_full[idx] = True
            direction_full[idx] = row["_direction"]
    labels = label_dataframe(
        df=df,
        signal_mask=signal_mask_full,
        atr_series=atr_full,
        direction_series=direction_full,
        forward_bars=LABEL_FORWARD_BARS,
    )
    # Merge labels back into result (align on timestamp)
    result = result.set_index("_timestamp")
    result["label"] = labels.reindex(result.index)
    # NOTE(review): after set_index("_timestamp"), reset_index restores a
    # column already named "_timestamp", so this rename looks like a no-op.
    result = result.reset_index().rename(columns={"index": "_timestamp"})
    # Keep only signal bars with valid labels
    result = result[result["_is_signal"] == 1].copy()
    result = result.dropna(subset=["label"])
    result["label"] = result["label"].astype(int)
    logger.info(
        f"{symbol}: {len(result)} labeled signals β "
        f"wr={result['label'].mean():.3f}"
    )
    return result
def build_dataset(
    symbols: list,
    bars: int = CANDLE_LIMIT,
    data_dir: Path = None,
) -> pd.DataFrame:
    """Fetch data and build the full labeled feature dataset.

    Parameters
    ----------
    symbols : list
        Symbols to fetch (ignored when loading CSVs from ``data_dir``).
    bars : int
        Number of OHLCV bars to fetch per symbol on the API path.
    data_dir : Path or None
        Optional directory of pre-saved ``<SYMBOL>.csv`` files; used
        instead of the API when it exists.

    Returns
    -------
    pd.DataFrame
        All labeled feature rows concatenated and sorted by timestamp.

    Raises
    ------
    ValueError
        If no symbol produced any labeled rows.
    """
    all_frames = []
    if data_dir and data_dir.exists():
        logger.info(f"Loading CSVs from {data_dir}")
        for csv_path in sorted(data_dir.glob("*.csv")):
            sym = csv_path.stem  # filename (minus .csv) is the symbol
            df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
            df.index = pd.to_datetime(df.index, utc=True)
            df.sort_index(inplace=True)
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)
    else:
        if data_dir:
            # Bug fix: previously a mistyped --data-dir silently fell through
            # to live fetching with no indication. Keep the fallback (backward
            # compatible) but make the surprise visible.
            logger.warning(f"data_dir {data_dir} does not exist; falling back to API fetch")
        logger.info(f"Fetching OHLCV for {len(symbols)} symbols ({bars} bars each)")
        # min_bars=60 matches the minimum-length guard in extract_features_and_labels
        ohlcv_map = fetch_multiple(symbols, limit=bars, min_bars=60)
        for sym, df in ohlcv_map.items():
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)
    if not all_frames:
        raise ValueError("No labeled data produced. Check symbols and API connectivity.")
    combined = pd.concat(all_frames, ignore_index=True)
    # Chronological order is required downstream by walk-forward validation.
    combined.sort_values("_timestamp", inplace=True)
    combined.reset_index(drop=True, inplace=True)
    logger.info(
        f"Dataset: {len(combined)} samples across {len(all_frames)} symbols | "
        f"overall wr={combined['label'].mean():.3f}"
    )
    return combined
def fit_final_model(
    X: np.ndarray,
    y: np.ndarray,
    params: dict,
    val_frac: float = 0.15,
) -> ModelBackend:
    """Train the production model on the whole dataset.

    The chronological tail (last ``val_frac`` of rows) is held out as an
    internal validation set for the backend's early stopping. When classes
    are meaningfully imbalanced, inverse-frequency sample weights are applied
    to the training rows.
    """
    n_train = int(len(X) * (1 - val_frac))
    X_train, X_val = X[:n_train], X[n_train:]
    y_train, y_val = y[:n_train], y[n_train:]

    # Inverse-frequency weighting, skipped when one class is (nearly) absent.
    pos_rate = y_train.mean()
    weights = None
    if 0.05 < pos_rate < 0.95:
        weights = np.where(y_train == 1, 1.0 / pos_rate, 1.0 / (1 - pos_rate))

    model = ModelBackend(params=params, calibrate=True)
    model.fit(X_train, y_train, X_val, y_val, sample_weight=weights)
    logger.info(f"Final model: {model.n_iter_} boosting rounds, backend={model.backend_name}")
    return model
def save_artifacts(
    backend: ModelBackend,
    threshold: float,
    summary: dict,
    dataset: pd.DataFrame,
):
    """Persist all training artifacts under ML_DIR.

    Writes, in order: the fitted model, the chosen decision threshold (with
    the walk-forward metrics that justified it), per-feature importances,
    and label-distribution statistics.
    """
    import joblib

    ML_DIR.mkdir(parents=True, exist_ok=True)

    # 1) Fitted model, joblib-pickled.
    joblib.dump(backend, MODEL_PATH)
    logger.info(f"Model saved β {MODEL_PATH}")

    # 2) Decision threshold plus supporting walk-forward metrics.
    payload = {
        "threshold": threshold,
        "objective": THRESHOLD_OBJECTIVE,
        "n_folds_used": summary.get("n_folds", 0),
        "mean_test_expectancy": summary.get("mean_expectancy"),
        "mean_test_sharpe": summary.get("mean_sharpe"),
        "mean_test_precision": summary.get("mean_precision"),
    }
    with open(THRESHOLD_PATH, "w") as fh:
        json.dump(payload, fh, indent=2)
    logger.info(f"Threshold saved β {THRESHOLD_PATH} (value={threshold:.4f})")

    # 3) Feature importances, most important first.
    importances = pd.DataFrame(
        {"feature": FEATURE_COLUMNS, "importance": backend.feature_importances_}
    ).sort_values("importance", ascending=False)
    importances.to_csv(FEATURE_IMP_PATH, index=False)
    logger.info(f"Feature importances saved β {FEATURE_IMP_PATH}")

    # 4) Label distribution stats for the training dataset.
    stats = compute_label_stats(pd.Series(dataset["label"].values))
    with open(LABEL_PATH, "w") as fh:
        json.dump(stats, fh, indent=2)
    logger.info(f"Label stats: {stats}")
def main(args):
    """Run the full training pipeline end to end.

    Steps: resolve the symbol universe, build the labeled dataset, run
    walk-forward validation to pick a probability threshold, report feature
    importances, fit the final model on all data, and save artifacts.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments with ``symbols``, ``use_defaults``, ``bars``,
        and ``data_dir`` attributes.
    """
    logger.info("=" * 60)
    logger.info("OKX TRADE FILTER β TRAINING PIPELINE")
    logger.info("=" * 60)
    # Symbol selection: --use-defaults wins over an explicit --symbols list.
    if args.use_defaults:
        symbols = DEFAULT_SYMBOLS
    elif args.symbols:
        symbols = args.symbols
    else:
        symbols = DEFAULT_SYMBOLS[:20]  # safe default for quick runs
    data_dir = Path(args.data_dir) if args.data_dir else None
    dataset = build_dataset(symbols, bars=args.bars, data_dir=data_dir)
    # Matrix/label extraction; timestamps drive chronological fold splits.
    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    timestamps = dataset["_timestamp"].values
    logger.info(f"Feature matrix: {X.shape} | Positive rate: {y.mean():.4f}")
    # Walk-forward validation
    logger.info("Running walk-forward validation...")
    wf_results = run_walk_forward(X, y, timestamps=timestamps, params=LGBM_PARAMS)
    summary = summarize_walk_forward(wf_results)
    logger.info("\n=== WALK-FORWARD SUMMARY ===")
    logger.info(f" Folds: {summary['n_folds']}")
    logger.info(f" Mean threshold: {summary['mean_threshold']:.4f} Β± {summary['std_threshold']:.4f}")
    logger.info(f" Mean expectancy: {summary['mean_expectancy']}")
    logger.info(f" Mean sharpe: {summary['mean_sharpe']}")
    logger.info(f" Mean precision: {summary['mean_precision']}")
    if summary.get("mean_expectancy") is not None and summary["mean_expectancy"] < 0:
        logger.warning("Negative mean expectancy! Model may not generalize. Check data quality.")
    # Choose final threshold: mean of walk-forward optimal thresholds
    final_threshold = summary["mean_threshold"]
    logger.info(f"\nFinal threshold: {final_threshold:.4f}")
    # Feature importance report
    imp_arr = np.array(summary["avg_feature_importance"])
    imp_pairs = sorted(zip(FEATURE_COLUMNS, imp_arr), key=lambda x: x[1], reverse=True)
    logger.info("\n=== TOP 15 FEATURES BY IMPORTANCE ===")
    for feat, imp in imp_pairs[:15]:
        # Text bar scaled to the max importance; guard against all-zero case.
        bar = "β" * int(imp / imp_arr.max() * 30) if imp_arr.max() > 0 else ""
        logger.info(f" {feat:<28} {imp:>8.2f} {bar}")
    # Fit final model on all data
    logger.info("\nFitting final model on full dataset...")
    final_backend = fit_final_model(X, y, LGBM_PARAMS, val_frac=0.15)
    # Save everything
    save_artifacts(final_backend, final_threshold, summary, dataset)
    logger.info("\nβ Training complete.")
if __name__ == "__main__":
    # CLI entry point; see the module docstring for usage examples.
    cli = argparse.ArgumentParser(description="Train OKX trade probability filter")
    cli.add_argument("--symbols", nargs="+", default=None, help="Symbol list, e.g. BTC-USDT ETH-USDT")
    cli.add_argument("--use-defaults", action="store_true", help="Use all DEFAULT_SYMBOLS from config")
    cli.add_argument("--bars", type=int, default=300, help="OHLCV bars to fetch per symbol")
    cli.add_argument("--data-dir", type=str, default=None, help="Directory of pre-saved CSV files")
    main(cli.parse_args())