# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Inference probe for the three LLM endpoints opencode runs against.

For each of vLLM, OpenAI, and the HF Inference Router, fires one
``/v1/chat/completions`` request with ``logprobs=true`` and verifies:

1. HTTP status is 200.
2. The response carries either ``message.content`` or ``message.tool_calls``.
3. ``choices[0].logprobs.content`` is non-null with at least one entry.
4. The first token's ``top_logprobs`` has the requested top-k count.

Endpoints are read from the sibling ``.env`` file (``envs/opencode_env/.env``).
A missing config skips that endpoint instead of failing the suite.

Run as pytest::

    PYTHONPATH=src:envs/opencode_env uv run pytest \\
        envs/opencode_env/tests/test_inference_endpoints.py -v -s

Run as a standalone script (prints a summary table)::

    python envs/opencode_env/tests/test_inference_endpoints.py
"""

from __future__ import annotations

import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import httpx
import pytest


# ---------------------------------------------------------------------------
# .env loader — no python-dotenv dep, since the package keeps deps minimal.
# ---------------------------------------------------------------------------
def _load_env_file(env_path: Path) -> None:
    """Populate ``os.environ`` from ``KEY=VALUE`` lines in ``env_path``.

    Existing process env vars take precedence so a shell ``export`` always
    wins over the ``.env`` file. Lines starting with ``#`` and blank lines
    are ignored. Surrounding single/double quotes on values are stripped.
    """
    if not env_path.exists():
        return
    for raw in env_path.read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        value = value.strip().strip('"').strip("'")
        if key and key not in os.environ:
            os.environ[key] = value


_ENV_PATH = Path(__file__).resolve().parents[1] / ".env"
_load_env_file(_ENV_PATH)
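
# Illustrative ``.env`` layout the loader above accepts. The variable names are
# the ones consumed by ENDPOINT_SPECS below; the URLs and keys are placeholders,
# not real endpoints or credentials. Anything already exported in the shell wins.
#
#   VLLM_URL=http://localhost:8000/v1
#   VLLM_MODEL=Qwen/Qwen3.5-4B
#   VLLM_API_KEY=intercepted
#   OPENAI_BASE_URL=https://api.openai.com/v1
#   OPENAI_MODEL=gpt-4o-mini
#   OPENAI_API_KEY="sk-..."
#   HF_ROUTER_BASE_URL=https://router.huggingface.co/v1
#   HF_ROUTER_MODEL=Qwen/Qwen3-4B-Instruct-2507:nscale
#   HF_ROUTER_API_KEY="hf_..."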


# ---------------------------------------------------------------------------
# Endpoint specs — one per kind of LLM endpoint we exercise.
# ---------------------------------------------------------------------------
@dataclass
class EndpointSpec:
    """Reads one endpoint's connection info out of the environment."""

    label: str
    base_url_env: str
    model_env: str
    api_key_env: str
    default_base_url: str | None = None
    default_model: str | None = None
    default_api_key: str = ""

    def resolve(self) -> "EndpointConfig | None":
        base_url = os.environ.get(self.base_url_env) or self.default_base_url or ""
        model = os.environ.get(self.model_env) or self.default_model or ""
        api_key = os.environ.get(self.api_key_env) or self.default_api_key
        if not base_url or not model or not api_key:
            return None
        return EndpointConfig(
            label=self.label, base_url=base_url, model=model, api_key=api_key
        )


@dataclass
class EndpointConfig:
    label: str
    base_url: str
    model: str
    api_key: str


def _resolve_chat_completions_url(base_url: str) -> str:
    """Build the fully-qualified ``/v1/chat/completions`` URL.

    Mirrors :func:`opencode_env.interception._resolve_upstream_url`: if the
    base already ends in ``/v1``, only ``/chat/completions`` is appended;
    otherwise the full ``/v1/chat/completions`` path is used.
    """
    base = base_url.rstrip("/")
    if base.endswith("/v1"):
        return f"{base}/chat/completions"
    return f"{base}/v1/chat/completions"

# Defaults below mirror what the OpenCode harness primitive uses by default.
# .env values override anything specified here.
ENDPOINT_SPECS: list[EndpointSpec] = [
    EndpointSpec(
        label="vllm",
        base_url_env="VLLM_URL",
        model_env="VLLM_MODEL",
        api_key_env="VLLM_API_KEY",
        default_api_key="intercepted",
        default_model="Qwen/Qwen3.5-4B",
    ),
    EndpointSpec(
        label="openai",
        base_url_env="OPENAI_BASE_URL",
        model_env="OPENAI_MODEL",
        api_key_env="OPENAI_API_KEY",
        default_base_url="https://api.openai.com/v1",
        default_model="gpt-4o-mini",
    ),
    EndpointSpec(
        label="hf_router",
        base_url_env="HF_ROUTER_BASE_URL",
        model_env="HF_ROUTER_MODEL",
        api_key_env="HF_ROUTER_API_KEY",
        default_base_url="https://router.huggingface.co/v1",
        default_model="Qwen/Qwen3-4B-Instruct-2507:nscale",
    ),
]


# ---------------------------------------------------------------------------
# Probe — one HTTP round trip per call; pure function, no side effects.
# ---------------------------------------------------------------------------
@dataclass
class ProbeResult:
    label: str
    base_url: str
    model: str
    status: int
    ok: bool
    completion_text: str = ""
    has_tool_calls: bool = False
    has_logprobs: bool = False
    top_logprobs_n: int = 0
    first_token: str = ""
    first_logprob: float | None = None
    latency_s: float = 0.0
    error: str = ""
    raw_response: dict[str, Any] = field(default_factory=dict)


def probe(
    cfg: EndpointConfig,
    *,
    top_logprobs: int = 5,
    max_tokens: int = 16,
    timeout_s: float = 90.0,
) -> ProbeResult:
    """Send one chat-completions request and return what the endpoint did.

    Never raises. Network / 4xx / 5xx errors land in ``ProbeResult.error`` so
    the caller can render a table without try/except scaffolding.
    """
    import time

    url = _resolve_chat_completions_url(cfg.base_url)
    body: dict[str, Any] = {
        "model": cfg.model,
        "messages": [{"role": "user", "content": "Reply with a single word: hi"}],
        "max_tokens": max_tokens,
        "logprobs": True,
        "top_logprobs": top_logprobs,
        "temperature": 0,
    }
    headers = {
        "Authorization": f"Bearer {cfg.api_key}",
        "Content-Type": "application/json",
    }
    start = time.time()
    try:
        r = httpx.post(url, json=body, headers=headers, timeout=timeout_s)
    except Exception as exc:  # noqa: BLE001
        return ProbeResult(
            label=cfg.label,
            base_url=cfg.base_url,
            model=cfg.model,
            status=0,
            ok=False,
            error=f"{type(exc).__name__}: {exc}",
            latency_s=time.time() - start,
        )
    latency = time.time() - start
    if r.status_code != 200:
        return ProbeResult(
            label=cfg.label,
            base_url=cfg.base_url,
            model=cfg.model,
            status=r.status_code,
            ok=False,
            error=r.text[:600],
            latency_s=latency,
        )
    try:
        data = r.json()
    except Exception as exc:  # noqa: BLE001
        return ProbeResult(
            label=cfg.label,
            base_url=cfg.base_url,
            model=cfg.model,
            status=r.status_code,
            ok=False,
            error=f"non-JSON body: {exc}",
            latency_s=latency,
        )
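
    # Abridged shape of the OpenAI-style response the parsing below expects
    # (field names follow the chat-completions logprobs schema; values are made up):
    #   {"choices": [{"message": {"content": "hi", "tool_calls": null},
    #                 "logprobs": {"content": [
    #                     {"token": "hi", "logprob": -0.01,
    #                      "top_logprobs": [{"token": "hi", "logprob": -0.01}, ...]}]}}]}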
    choice = (data.get("choices") or [{}])[0]
    msg = choice.get("message") or {}
    completion_text = msg.get("content") or ""
    has_tool_calls = bool(msg.get("tool_calls"))
    lp = choice.get("logprobs")
    content_lp = lp.get("content") if isinstance(lp, dict) else None
    has_logprobs = bool(content_lp)
    first_token = ""
    first_logprob: float | None = None
    top_n = 0
    if has_logprobs and content_lp:
        first = content_lp[0]
        first_token = str(first.get("token", ""))
        lp_val = first.get("logprob")
        if lp_val is not None:
            first_logprob = float(lp_val)
        top_n = len(first.get("top_logprobs") or [])
    return ProbeResult(
        label=cfg.label,
        base_url=cfg.base_url,
        model=cfg.model,
        status=r.status_code,
        ok=True,
        completion_text=completion_text,
        has_tool_calls=has_tool_calls,
        has_logprobs=has_logprobs,
        top_logprobs_n=top_n,
        first_token=first_token,
        first_logprob=first_logprob,
        latency_s=latency,
        raw_response=data,
    )


# ---------------------------------------------------------------------------
# pytest entrypoints — one parametrized test per endpoint.
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("spec", ENDPOINT_SPECS, ids=[s.label for s in ENDPOINT_SPECS])
def test_endpoint_responds(spec: EndpointSpec) -> None:
    """Endpoint accepts a chat-completions call and returns a 2xx body."""
    cfg = spec.resolve()
    if cfg is None:
        pytest.skip(
            f"{spec.label} not configured (set {spec.base_url_env} / "
            f"{spec.model_env} / {spec.api_key_env} in .env)"
        )
    result = probe(cfg)
    assert result.ok, f"{cfg.label}: HTTP {result.status} — {result.error}"
    # ``logprobs.content`` populated implies the model generated at least one
    # token (either visible content, a tool-call argument, or a reasoning
    # token for Qwen3-thinking variants). That is the signal we want — empty
    # completion + empty tool_calls is fine when reasoning tokens are present.
    assert result.has_logprobs or result.completion_text or result.has_tool_calls, (
        f"{cfg.label}: model produced no output at all. "
        f"Response: {str(result.raw_response)[:500]}"
    )


@pytest.mark.parametrize("spec", ENDPOINT_SPECS, ids=[s.label for s in ENDPOINT_SPECS])
def test_endpoint_returns_logprobs(spec: EndpointSpec) -> None:
    """Endpoint honors ``logprobs=true`` and returns per-token logprobs.

    Failing this test means the endpoint silently drops logprobs (HF Router
    providers like Novita / Hyperbolic / Featherless behave this way) — the
    transparent proxy has nothing to capture and Mode B GRPO will train on
    empty per-token logps.
    """
    cfg = spec.resolve()
    if cfg is None:
        pytest.skip(
            f"{spec.label} not configured (set {spec.base_url_env} / "
            f"{spec.model_env} / {spec.api_key_env} in .env)"
        )
    result = probe(cfg)
    assert result.ok, f"{cfg.label}: HTTP {result.status} — {result.error}"
    assert result.has_logprobs, (
        f"{cfg.label}: endpoint returned 200 but logprobs.content is null. "
        f"This provider does not support logprobs. Pick a different provider "
        f"(together / nscale / scaleway) or run opencode in mode='black_box'."
    )
    assert result.top_logprobs_n >= 1, (
        f"{cfg.label}: top_logprobs has {result.top_logprobs_n} entries, "
        f"expected >= 1"
    )
    assert result.first_logprob is not None, (
        f"{cfg.label}: first token has no logprob value"
    )


# ---------------------------------------------------------------------------
# Standalone runner — prints a summary table.
# ---------------------------------------------------------------------------
def _format_summary(results: list[ProbeResult], skipped: list[str]) -> str:
    rows: list[str] = []
    rows.append("-" * 96)
    rows.append(
        f"{'endpoint':<10} {'status':<7} {'logprobs':<14} {'top-n':<6} "
        f"{'first-token':<14} {'first-logp':<11} {'latency':<8} notes"
    )
    rows.append("-" * 96)
    for r in results:
        if r.status == 0:
            status_str = "ERR"
        else:
            status_str = str(r.status)
        if not r.ok:
            lp_str = "n/a"
        elif r.has_logprobs:
            lp_str = f"yes ({r.top_logprobs_n})"
        else:
            lp_str = "DROPPED"
        first_tok_str = repr(r.first_token) if r.first_token else "-"
        first_lp_str = (
            f"{r.first_logprob:+.3f}" if r.first_logprob is not None else "-"
        )
        latency_str = f"{r.latency_s:.2f}s"
        notes = ""
        if not r.ok:
            notes = r.error[:50].replace("\n", " ")
        elif not r.has_logprobs:
            notes = "silent logprob drop"
        rows.append(
            f"{r.label:<10} {status_str:<7} {lp_str:<14} "
            f"{r.top_logprobs_n:<6} {first_tok_str:<14} "
            f"{first_lp_str:<11} {latency_str:<8} {notes}"
        )
    rows.append("-" * 96)
    if skipped:
        rows.append("")
        rows.append("Skipped (not configured in .env):")
        for s in skipped:
            rows.append(f" - {s}")
    return "\n".join(rows)


def main() -> int:
    print(f"Loading env from {_ENV_PATH}\n")
    results: list[ProbeResult] = []
    skipped: list[str] = []
    for spec in ENDPOINT_SPECS:
        cfg = spec.resolve()
        if cfg is None:
            skipped.append(
                f"{spec.label} (set {spec.base_url_env} / {spec.model_env} / "
                f"{spec.api_key_env})"
            )
            continue
        print(f"-> probing {cfg.label}: {cfg.base_url} model={cfg.model}")
        r = probe(cfg)
        results.append(r)
        if not r.ok:
            print(f" HTTP {r.status}: {r.error[:200]}")
        else:
            print(
                f" HTTP {r.status} logprobs={r.has_logprobs} "
                f"top_n={r.top_logprobs_n} "
                f"content={r.completion_text!r:.60}"
            )
    print()
    print(_format_summary(results, skipped))
    if not results:
        print("\nNo endpoints configured. Fill in .env and re-run.")
        return 2
    bad = [r for r in results if not r.ok or not r.has_logprobs]
    if bad:
        print(f"\n{len(bad)}/{len(results)} endpoint(s) failed or lack logprobs.")
        return 1
    print(f"\nAll {len(results)} configured endpoint(s) passed.")
    return 0


if __name__ == "__main__":
    sys.exit(main())