| |
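"""Run an oracle-retrieval QA upper bound.

By default (--oracle_source answer) the model receives only the gold answer
sessions listed in answer_session_ids. This separates answer synthesis errors
from retrieval errors. Other --oracle_source values feed the scenario sessions,
or a truncated / randomly subsampled haystack, to the model instead.
"""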
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, List |
|
|
# Put the directory one level above this file's folder at the front of
# sys.path (unless it is already listed) so modules located there can be
# imported by the statements that follow.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from openai import OpenAI |
|
|
# Azure support is optional. If either the Azure client class or the
# azure-identity helpers cannot be imported, both AzureOpenAI and credential
# fall back to None.
try:
    from openai import AzureOpenAI
    from azure.identity import (
        AzureCliCredential,
        ChainedTokenCredential,
        ManagedIdentityCredential,
        get_bearer_token_provider,
    )

    AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
    # A bearer-token provider is only built when a scope is configured; it
    # chains the Azure CLI credential with the managed-identity credential.
    credential = (
        get_bearer_token_provider(
            ChainedTokenCredential(
                AzureCliCredential(),
                ManagedIdentityCredential(),
            ),
            AZURE_OAUTH_SCOPE,
        )
        if AZURE_OAUTH_SCOPE
        else None
    )
except ImportError:
    AzureOpenAI = None
    credential = None
|
|
| from model_zoo import model_zoo |
|
|
|
|
| |
# Endpoint handed to the Azure client; taken from AZURE_OPENAI_ENDPOINT and
# empty when that variable is unset.
AZURE_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "")

# Base URL used for the --tritonai backend; taken from LITELLM_BASE_URL and
# empty when that variable is unset.
TRITONAI_BASE_URL = os.getenv("LITELLM_BASE_URL", "")
|
|
|
|
def read_json(path: str | Path) -> Any:
    """Parse the UTF-8 encoded JSON document stored at *path* and return it."""
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
|
|
|
|
def read_existing_qids(path: str | Path) -> set[str]:
    """Collect the question ids already recorded in a JSON-lines file.

    Returns an empty set when *path* does not exist. Otherwise every line that
    is not pure whitespace is parsed as JSON and its ``"question_id"`` value is
    added to the result.
    """
    if not Path(path).exists():
        return set()
    with open(path, "r", encoding="utf-8") as handle:
        return {json.loads(raw)["question_id"] for raw in handle if raw.strip()}
|
|
|
|
def retryable_status(exc: Exception) -> int | None:
    """Best-effort extraction of an HTTP status code from an exception.

    Lookup order:
      1. ``exc.status_code`` or ``exc.http_status`` (first truthy one);
      2. ``exc.response.status_code``;
      3. the lower-cased exception message, searched for 429 / 500 / 503 / 504
         or a matching phrase.

    The numeric message match only fires when the code is not embedded in a
    longer digit run, so text such as "150000 tokens" or "request id 14290" is
    no longer mistaken for HTTP 500 / 429.

    Returns:
        The status as an ``int``, or ``None`` when nothing could be determined.
    """
    import re  # local import: the module header does not import ``re``

    status = getattr(exc, "status_code", None) or getattr(exc, "http_status", None)
    if status is not None:
        return int(status)
    resp = getattr(exc, "response", None)
    if resp is not None and getattr(resp, "status_code", None) is not None:
        return int(resp.status_code)

    msg = str(exc).lower()
    # Checked in this order; the first code whose number or phrase appears wins.
    fallbacks = (
        (429, ("rate limit",)),
        (500, ("internal server error",)),
        (503, ("api configuration unavailable",)),
        (504, ("gateway time-out", "gateway timeout")),
    )
    for code, phrases in fallbacks:
        standalone_code = re.search(rf"(?<!\d){code}(?!\d)", msg) is not None
        if standalone_code or any(phrase in msg for phrase in phrases):
            return code
    return None
|
|
|
|
def make_client(args, api_version: str):
    """Build the chat-completions client selected by the backend flags.

    The flags are consulted in a fixed priority order: ``args.nvidia``,
    ``args.tritonai``, ``args.vllm``, ``args.debug``. The first one that is set
    yields an ``OpenAI`` client configured from the matching environment
    variables. When none is set an ``AzureOpenAI`` client is returned, using
    ``AZURE_ENDPOINT``, the module-level ``credential`` token provider and
    *api_version*.

    Raises:
        RuntimeError: no flag is set and ``AzureOpenAI`` is unavailable.
    """
    if args.nvidia:
        openai_kwargs = {
            "api_key": os.getenv("NV_API_KEY"),
            "base_url": "https://inference-api.nvidia.com/v1",
        }
    elif args.tritonai:
        openai_kwargs = {
            "api_key": os.getenv("TRITONAI_API_KEY"),
            "base_url": TRITONAI_BASE_URL,
        }
    elif args.vllm:
        openai_kwargs = {
            "api_key": os.getenv("VLLM_API_KEY", "EMPTY"),
            "base_url": os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
        }
    elif args.debug:
        openai_kwargs = {"api_key": os.getenv("OPENAI_API_KEY")}
    else:
        # No OpenAI-compatible backend was requested: fall through to Azure.
        if AzureOpenAI is None:
            raise RuntimeError("AzureOpenAI is not available. Use --nvidia, --tritonai, --vllm, or --debug.")
        return AzureOpenAI(
            azure_endpoint=AZURE_ENDPOINT,
            azure_ad_token_provider=credential,
            api_version=api_version,
        )
    return OpenAI(**openai_kwargs)
|
|
|
|
def llm_call(client, deployment_name: str, prompt: str, use_user_role: bool, max_retries: int = 999):
    """Send *prompt* as a single chat message and return the completion.

    Args:
        client: object exposing ``chat.completions.create(model=..., messages=...)``.
        deployment_name: value passed as ``model``.
        prompt: message content.
        use_user_role: send the message with role ``"user"`` when true,
            otherwise with role ``"system"``.
        max_retries: maximum number of attempts.

    Returns:
        Whatever ``client.chat.completions.create`` returns.

    Raises:
        RuntimeError: every attempt failed with a retryable status (or
            ``max_retries`` is not positive). The last underlying exception,
            if any, is attached as ``__cause__``.
        Exception: any error whose status is not retryable is re-raised as is.
    """
    role = "user" if use_user_role else "system"
    last_exc = None
    for attempt in range(max_retries):
        try:
            return client.chat.completions.create(
                model=deployment_name,
                messages=[{"role": role, "content": prompt}],
            )
        except Exception as exc:
            status = retryable_status(exc)
            # NOTE(review): 403 is treated as transient here - presumably to
            # ride out credential refreshes; confirm this is intended.
            if status not in (403, 429, 500, 503, 504):
                raise
            last_exc = exc
            if attempt + 1 >= max_retries:
                # Out of attempts: do not sleep again just to fail afterwards.
                break
            wait = min(120, 30 + attempt * 5)
            print(f"[WARN] HTTP {status}; sleeping {wait}s then retrying", flush=True)
            time.sleep(wait)
    raise RuntimeError(
        f"llm_call gave up after {max_retries} attempt(s) without a successful completion"
    ) from last_exc
|
|
|
|
def oracle_session_ids(entry: Dict[str, Any], source: str) -> List[str]:
    """Return the session-id list of *entry* that corresponds to *source*.

    ``"answer"`` reads ``answer_session_ids``, ``"scenario"`` reads
    ``scenario_session_ids`` and both haystack sources read
    ``haystack_session_ids``. A missing key yields an empty list.

    Raises:
        ValueError: *source* is not one of the four known sources.
    """
    field_by_source = {
        "answer": "answer_session_ids",
        "scenario": "scenario_session_ids",
        "haystack_truncate": "haystack_session_ids",
        "haystack_subsample": "haystack_session_ids",
    }
    if source not in field_by_source:
        raise ValueError(f"Unknown oracle source: {source}")
    return entry.get(field_by_source[source], [])
|
|
|
|
def select_haystack_session_ids(
    entry: Dict[str, Any],
    all_sessions: Dict[str, List[Dict[str, str]]],
    source: str,
    max_haystack_tokens: int,
    subsample_n: int,
) -> List[str]:
    """Choose which haystack sessions of *entry* are fed to the model.

    ``haystack_truncate`` walks ``haystack_session_ids`` in order and keeps
    sessions until adding the next one would exceed the budget of
    ``max_haystack_tokens * 4`` characters (a chars/4 token estimate). The
    first session is always kept, even when it alone is over budget.

    ``haystack_subsample`` draws ``min(subsample_n, len(haystack))`` ids with a
    ``random.Random`` seeded by ``entry["question_id"]``, so the draw is
    repeatable per question; a non-positive count gives an empty list.

    Raises:
        ValueError: *source* is neither of the two haystack policies.
    """
    candidates = entry.get("haystack_session_ids", [])

    if source == "haystack_subsample":
        sampler = random.Random(entry["question_id"])
        sample_size = min(subsample_n, len(candidates))
        if sample_size <= 0:
            return []
        return sampler.sample(candidates, sample_size)

    if source != "haystack_truncate":
        raise ValueError(f"select_haystack_session_ids called with non-haystack source: {source}")

    limit_chars = max_haystack_tokens * 4
    chosen: List[str] = []
    running_total = 0
    for session_id in candidates:
        session_size = 0
        for turn in all_sessions.get(session_id, []):
            session_size += len(turn.get("content") or "")
        # Stop at the first session that overflows, unless nothing is kept yet.
        if chosen and running_total + session_size > limit_chars:
            break
        chosen.append(session_id)
        running_total += session_size
    return chosen
|
|
|
|
def build_session_prompt(
    entry: Dict[str, Any],
    all_sessions: Dict[str, List[Dict[str, str]]],
    source: str,
    selected_session_ids: List[str],
) -> str:
    """Assemble the QA prompt from the selected sessions and the question.

    Each selected session becomes a block holding its id, its date (looked up
    by pairing ``haystack_session_ids`` with ``haystack_dates``; empty when the
    id is not in the haystack) and its turns serialized as JSON with only the
    ``role`` and ``content`` fields. When no session is selected, a placeholder
    sentence is used in place of the blocks. *source* only affects the wording
    of the introductory sentence.
    """
    dates_by_session = {}
    for haystack_id, haystack_date in zip(entry["haystack_session_ids"], entry["haystack_dates"]):
        dates_by_session[haystack_id] = haystack_date

    def render_block(session_id):
        # Keep only role/content so other per-turn fields never reach the prompt.
        slim_turns = [
            {"role": turn.get("role"), "content": turn.get("content")}
            for turn in all_sessions.get(session_id, [])
        ]
        return "Session ID: {sid}\nSession Date: {date}\nSession Content:\n{content}".format(
            sid=session_id,
            date=dates_by_session.get(session_id, ""),
            content=json.dumps(slim_turns, ensure_ascii=False),
        )

    rendered = [render_block(session_id) for session_id in selected_session_ids]
    evidence = (
        "\n\n".join(rendered)
        if rendered
        else "(No sessions are available for this oracle source.)"
    )

    descriptions = {
        "answer": "the gold answer-relevant chat history sessions",
        "scenario": "all chat history sessions from the question scenario",
        "haystack_truncate": "chat history sessions from the user's haystack (truncated to context window)",
        "haystack_subsample": "a random subsample of chat history sessions from the user's haystack",
    }
    source_desc = descriptions.get(source, "chat history sessions")

    template = """I will give you {source_desc} between an assistant and a user.
Answer the question using only these sessions. If the provided sessions do not contain enough information to answer, say that the information is not available from the provided chat history.

Chat history sessions:

{evidence}

Current Date: {question_date}
Question: {question}
Answer:"""
    return template.format(
        source_desc=source_desc,
        evidence=evidence,
        question_date=entry["question_date"],
        question=entry["question"],
    )
|
|
|
|
def usage_dict(completion) -> Dict[str, int]:
    """Extract the token counters of *completion* as a plain dict.

    The result always has the keys ``prompt_tokens``, ``completion_tokens``
    and ``total_tokens``. A missing ``usage`` attribute, a missing counter or
    a falsy counter value is reported as 0.
    """
    counters = ("prompt_tokens", "completion_tokens", "total_tokens")
    usage = getattr(completion, "usage", None)
    if usage is None:
        return dict.fromkeys(counters, 0)
    return {name: getattr(usage, name, 0) or 0 for name in counters}
|
|
|
|
def main() -> None:
    """Command-line entry point: answer every question and append JSONL rows.

    For each entry of ``--in_file`` whose ``question_id`` is not already in
    ``--out_file``, the sessions selected by ``--oracle_source`` are turned
    into a prompt and sent to the model. One JSON row is appended to
    ``--out_file`` (flushed immediately, so an interrupted run can resume) and
    a short progress record is printed to stdout.

    When the model returns no content the request is repeated, up to three
    calls in total. ``token_budget`` counts every call that was made (tokens
    summed, ``n_calls`` equal to the real number of calls), while
    ``n_prompt_tok`` / ``n_completion_tok`` describe the final call only.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_file", required=True)
    parser.add_argument("--out_file", required=True)
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--all_sessions_file", default="dataset/all_sessions.json")
    parser.add_argument("--oracle_source",
                        choices=["answer", "scenario", "haystack_truncate", "haystack_subsample"],
                        default="answer",
                        help=("answer = gold answer sessions; scenario = scenario sessions; "
                              "haystack_truncate = haystack truncated to --max_haystack_tokens; "
                              "haystack_subsample = random --subsample_n sessions from haystack"))
    parser.add_argument("--max_haystack_tokens", type=int, default=900_000,
                        help="Token budget for haystack_truncate (chars/4 estimate); default 900K")
    parser.add_argument("--subsample_n", type=int, default=20,
                        help="N sessions for haystack_subsample (seeded by question_id); default 20")
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--vllm", action="store_true", default=False)
    parser.add_argument("--tritonai", action="store_true", default=False)
    parser.add_argument("--nvidia", action="store_true", default=False)
    args = parser.parse_args()

    deployment_name, api_version = model_zoo[args.model_name]
    client = make_client(args, api_version)
    # These two backends are sent the prompt as a user message instead of a
    # system message.
    use_user_role = args.nvidia or args.tritonai

    entries = read_json(args.in_file)
    if args.limit is not None:
        entries = entries[: args.limit]
    all_sessions = read_json(args.all_sessions_file)
    existing = read_existing_qids(args.out_file)

    # One initial call plus up to two repeats when the model returns no content.
    max_answer_calls = 3

    Path(args.out_file).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out_file, "a", encoding="utf-8") as out_f:
        for idx, entry in enumerate(entries):
            qid = entry["question_id"]
            if qid in existing:
                continue
            start = time.time()
            if args.oracle_source in ("haystack_truncate", "haystack_subsample"):
                selected_session_ids = select_haystack_session_ids(
                    entry, all_sessions, args.oracle_source,
                    max_haystack_tokens=args.max_haystack_tokens,
                    subsample_n=args.subsample_n,
                )
            else:
                selected_session_ids = oracle_session_ids(entry, args.oracle_source)
            prompt = build_session_prompt(
                entry, all_sessions, args.oracle_source, selected_session_ids
            )

            content = None
            usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
            spent_prompt_tokens = 0
            spent_completion_tokens = 0
            n_calls = 0
            for _ in range(max_answer_calls):
                completion = llm_call(client, deployment_name, prompt, use_user_role=use_user_role)
                n_calls += 1
                usage = usage_dict(completion)
                # Every call costs tokens, including the ones whose content was empty.
                spent_prompt_tokens += usage["prompt_tokens"]
                spent_completion_tokens += usage["completion_tokens"]
                content = completion.choices[0].message.content if completion.choices else None
                if content is not None:
                    break

            answer = (content or "").strip()
            row = {
                "q_idx": idx,
                "question_id": qid,
                "hypothesis": answer,
                "oracle_source": args.oracle_source,
                "oracle_session_ids": selected_session_ids,
                "n_oracle_sessions": len(selected_session_ids),
                "n_prompt_tok": usage["prompt_tokens"],
                "n_completion_tok": usage["completion_tokens"],
                "token_budget": {
                    "oracle_answer": {
                        "prompt_tokens": spent_prompt_tokens,
                        "completion_tokens": spent_completion_tokens,
                        "n_calls": n_calls,
                    },
                    "total": {
                        "prompt_tokens": spent_prompt_tokens,
                        "completion_tokens": spent_completion_tokens,
                        "n_calls": n_calls,
                    },
                },
                "wall_time_sec": time.time() - start,
            }
            print(json.dumps(row, ensure_ascii=False), file=out_f, flush=True)
            print(json.dumps({
                "q_idx": idx,
                "question_id": qid,
                "n_oracle_sessions": row["n_oracle_sessions"],
                "wall_time_sec": round(row["wall_time_sec"], 3),
            }), flush=True)


if __name__ == "__main__":
    main()
|
|