import warnings

warnings.filterwarnings("ignore")

import gradio as gr
import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
from pandas.tseries.frequencies import to_offset
from gluonts.dataset.common import ListDataset

# --- Moirai 2.0 via Uni2TS ---
# Make sure your requirements install Uni2TS from GitHub:
#   git+https://github.com/SalesforceAIResearch/uni2ts.git
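# A plausible requirements.txt for this Space could list the packages imported in
# this file (exact pins are assumptions; only the Uni2TS git URL is required verbatim):
#   git+https://github.com/SalesforceAIResearch/uni2ts.git
#   gradio
#   gluonts
#   yfinance
#   matplotlib
#   pandas
#   numpy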
try:
    from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module
except Exception as e:
    raise ImportError(
        "Moirai 2.0 not found in your Uni2TS install.\n"
        "Ensure requirements.txt includes:\n"
        "  git+https://github.com/SalesforceAIResearch/uni2ts.git\n"
        f"Original error: {e}"
    )
MODEL_ID = "Salesforce/moirai-2.0-R-small"
DEFAULT_CONTEXT = 1680  # from Moirai examples, but we clamp to the series length

# ----------------------------
# Model loader (single instance)
# ----------------------------
_MODULE = None


def load_module():
    global _MODULE
    if _MODULE is None:
        _MODULE = Moirai2Module.from_pretrained(MODEL_ID)
    return _MODULE
# ----------------------------
# Shared forecasting core
# ----------------------------
def _future_index(last_idx: pd.Timestamp, freq: str, horizon: int) -> pd.DatetimeIndex:
    off = to_offset(freq)
    start = last_idx + off
    return pd.date_range(start=start, periods=horizon, freq=freq)
def _run_forecast_on_series(
    y: pd.Series,
    freq: str,
    horizon: int,
    context_hint: int,
    title: str,
):
    if len(y) < 50:
        raise gr.Error("Need at least 50 points to forecast.")
    ctx = int(np.clip(context_hint or DEFAULT_CONTEXT, 32, len(y)))
    target = y.values[-ctx:].astype(np.float32)
    start_idx = y.index[-ctx]
    ds = ListDataset([{"start": start_idx, "target": target}], freq=freq)

    module = load_module()
    model = Moirai2Forecast(
        module=module,
        prediction_length=int(horizon),
        context_length=ctx,
        target_dim=1,
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )
    predictor = model.create_predictor(batch_size=32)  # device handled internally
    forecast = next(iter(predictor.predict(ds)))

    if hasattr(forecast, "mean"):
        yhat = np.asarray(forecast.mean)
    elif hasattr(forecast, "quantile"):
        yhat = np.asarray(forecast.quantile(0.5))
    elif hasattr(forecast, "samples"):
        yhat = np.asarray(forecast.samples).mean(axis=0)
    else:
        yhat = np.asarray(forecast)
    yhat = np.asarray(yhat).ravel()[:horizon]

    future_idx = _future_index(y.index[-1], freq, horizon)
    pred = pd.Series(yhat, index=future_idx, name="prediction")

    # Plot
    fig = plt.figure(figsize=(10, 5))
    plt.plot(y.index, y.values, label="history")
    plt.plot(pred.index, pred.values, label="forecast")
    plt.title(title)
    plt.xlabel("Time"); plt.ylabel("Value"); plt.legend(); plt.tight_layout()

    out_df = pd.DataFrame({"date": pred.index, "prediction": pred.values})
    return fig, out_df
# ----------------------------
# Ticker helpers
# ----------------------------
def fetch_series(ticker: str, years: int) -> pd.Series:
    """Fetch daily close prices and align to business-day frequency."""
    data = yf.download(
        ticker,
        period=f"{years}y",
        interval="1d",
        auto_adjust=True,
        progress=False,
        threads=True,
    )
    if data is None or data.empty:
        raise gr.Error(f"No price data found for '{ticker}'.")
    col = "Close" if "Close" in data.columns else ("Adj Close" if "Adj Close" in data.columns else None)
    if col is None:
        raise gr.Error(f"Unexpected columns from yfinance: {list(data.columns)}")
    if isinstance(data.columns, pd.MultiIndex):
        if ticker in data[col].columns:
            s = data[col][ticker]
        else:
            s = data[col].iloc[:, 0]
    else:
        s = data[col]
    y = s.copy()
    y.name = ticker
    y.index = pd.DatetimeIndex(y.index).tz_localize(None)
    # Business-day index; forward-fill holidays
    bidx = pd.bdate_range(y.index.min(), y.index.max())
    y = y.reindex(bidx).ffill()
    if y.isna().all():
        raise gr.Error(f"Only missing values for '{ticker}'.")
    return y
def forecast_ticker(ticker: str, horizon: int, lookback_years: int, context_hint: int):
    ticker = (ticker or "").strip().upper()
    if not ticker:
        raise gr.Error("Please enter a ticker symbol (e.g., AAPL).")
    if horizon < 1:
        raise gr.Error("Forecast horizon must be at least 1.")
    y = fetch_series(ticker, lookback_years)
    return _run_forecast_on_series(y, "B", horizon, context_hint, f"{ticker} — forecast (Moirai 2.0 R-small)")
# ----------------------------
# CSV helpers
# ----------------------------
def _read_csv_columns(file_path: str) -> pd.DataFrame:
    try:
        df = pd.read_csv(file_path)
    except Exception:
        df = pd.read_csv(file_path, sep=None, engine="python")
    return df


def _coerce_numeric_series(s: pd.Series) -> pd.Series:
    s = pd.to_numeric(s, errors="coerce")
    return s.dropna().astype(np.float32)
def build_series_from_csv(file, value_col: str, date_col: str, freq_choice: str):
    """
    Returns (series y with a DatetimeIndex, freq string).
    - If date_col is provided: parse dates and infer/align the frequency.
    - If NO date_col: create a synthetic date index using freq_choice (defaults to 'D' if auto/blank).
    """
    if file is None:
        raise gr.Error("Please upload a CSV file.")
    # Gradio file object handling (v4/v5)
    path = getattr(file, "name", None) or getattr(file, "path", None) or (file if isinstance(file, str) else None)
    if path is None:
        raise gr.Error("Could not read the uploaded file path.")
    df = _read_csv_columns(path)
    if df.empty:
        raise gr.Error("Uploaded file is empty.")

    # Value column selection
    value_col = (value_col or "").strip()
    if value_col:
        if value_col not in df.columns:
            raise gr.Error(f"Value column '{value_col}' not found. Available: {list(df.columns)}")
        vals = _coerce_numeric_series(df[value_col])
    else:
        numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
        if numeric_cols:
            vals = _coerce_numeric_series(df[numeric_cols[0]])
        else:
            vals = _coerce_numeric_series(df.iloc[:, 0])
    if vals.empty or len(vals) < 10:
        raise gr.Error("Not enough numeric values after parsing (need at least 10).")

    date_col = (date_col or "").strip()
    freq_choice_norm = (freq_choice or "").strip().upper()

    if date_col:
        if date_col not in df.columns:
            raise gr.Error(f"Date column '{date_col}' not found. Available: {list(df.columns)}")
        dt = pd.to_datetime(df[date_col], errors="coerce")
        mask = dt.notna() & vals.notna()
        dt = pd.DatetimeIndex(dt[mask]).tz_localize(None)
        vals = vals[mask]
        if len(vals) < 10:
            raise gr.Error("Too few valid rows after parsing date/value columns.")
        # Sort & dedupe index BEFORE inferring/aligning freq
        order = np.argsort(dt.values)
        dt = dt[order]
        vals = vals.iloc[order].reset_index(drop=True)
        y = pd.Series(vals.values, index=dt, name=value_col or "value").copy()
        y = y[~y.index.duplicated(keep="last")].sort_index()
        # Choose frequency
        if freq_choice_norm and freq_choice_norm != "AUTO":
            freq = freq_choice_norm
        else:
            inferred = pd.infer_freq(y.index)
            if inferred:
                freq = inferred
            else:
                weekday_ratio = (y.index.dayofweek < 5).mean()
                freq = "B" if weekday_ratio > 0.95 else "D"
        # Align to the chosen frequency
        y = y.asfreq(freq, method="ffill")
    else:
        # No date column: build a synthetic index
        freq = "D" if (not freq_choice_norm or freq_choice_norm == "AUTO") else freq_choice_norm
        idx = pd.date_range(start="2000-01-01", periods=len(vals), freq=freq)
        y = pd.Series(vals.values, index=idx, name=value_col or "value").copy()

    if y.isna().all():
        raise gr.Error("Series is all-NaN after processing.")
    return y, freq
def forecast_csv(file, value_col: str, date_col: str, freq_choice: str, horizon: int, context_hint: int):
    y, freq = build_series_from_csv(file, value_col, date_col, freq_choice)
    return _run_forecast_on_series(y, freq, horizon, context_hint, f"Uploaded series — forecast (freq={freq})")
# ----------------------------
# UI
# ----------------------------
with gr.Blocks(title="Moirai 2.0 — Time Series Forecast (Research)") as demo:
    gr.Markdown(
        """
        # Moirai 2.0 — Time Series Forecast (Research)
        Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ticker *or* a generic CSV time series.

        > **Important**: Research/educational use only. Not investment advice. Model license: **CC-BY-NC-4.0 (non-commercial)**.
        """
    )

    with gr.Tab("By Ticker"):
        with gr.Row():
            ticker = gr.Textbox(label="Ticker", value="AAPL", placeholder="e.g., AAPL, MSFT, TSLA")
            horizon_t = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (steps)")
        with gr.Row():
            lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
            ctx_t = gr.Slider(64, 5000, value=1680, step=16, label="Context length")
        run_t = gr.Button("Run forecast", variant="primary")
        plot_t = gr.Plot(label="History + Forecast")
        table_t = gr.Dataframe(label="Forecast table", interactive=False)
        run_t.click(forecast_ticker, inputs=[ticker, horizon_t, lookback, ctx_t], outputs=[plot_t, table_t])

    with gr.Tab("Upload CSV"):
        gr.Markdown(
            "Upload a CSV with either (1) a **date/time column** and a **value column**, "
            "or (2) just a numeric value column (then choose a frequency, or leave **auto** to default to **D**)."
        )
        with gr.Row():
            file = gr.File(label="CSV file", file_types=[".csv"])
        with gr.Row():
            date_col = gr.Textbox(label="Date/time column (optional)", placeholder="e.g., date, timestamp")
            value_col = gr.Textbox(label="Value column (optional — auto-detects first numeric)", placeholder="e.g., value, close")
        with gr.Row():
            freq_choice = gr.Dropdown(
                label="Frequency",
                value="auto",
                choices=["auto", "B", "D", "H", "W", "M", "MS"],
                info="If no date column, 'auto' defaults to D (daily).",
            )
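            # The non-"auto" choices are pandas offset aliases: B = business day,
            # D = calendar day, H = hourly, W = weekly, M = month end, MS = month start.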
        with gr.Row():
            horizon_u = gr.Slider(1, 500, value=60, step=1, label="Forecast horizon (steps)")
            ctx_u = gr.Slider(32, 5000, value=512, step=16, label="Context length")
        run_u = gr.Button("Run forecast on CSV", variant="primary")
        plot_u = gr.Plot(label="History + Forecast (CSV)")
        table_u = gr.Dataframe(label="Forecast table (CSV)", interactive=False)
        run_u.click(
            forecast_csv,
            inputs=[file, value_col, date_col, freq_choice, horizon_u, ctx_u],
            outputs=[plot_u, table_u],
        )
if __name__ == "__main__":
    demo.launch()