# HyperBrickCaseOps / inference.py
# Uploaded via huggingface_hub (revision 4f129c9, verified).
"""Baseline inference script for the SupportDesk OpenEnv submission."""
from __future__ import annotations
import asyncio
import json
import os
import re
from dataclasses import dataclass
from typing import Any
try:
from openai import OpenAI
except ImportError: # pragma: no cover - local fallback mode
OpenAI = None # type: ignore[assignment]
from client import SupportDeskEnv
from graders import grade_case
from models import SupportDeskAction, SupportDeskObservation
from policies import heuristic_action
from server.supportdesk_environment import SupportDeskEnvironment
from tasks import get_task, list_task_ids
# OpenAI-compatible router endpoint and credentials; HF_TOKEN takes priority
# over a generic API_KEY when both are set.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = HF_TOKEN or os.getenv("API_KEY")
# When set, the episode runs against a Docker-packaged environment instead of
# the in-process SupportDeskEnvironment (see main()).
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
TASK_NAME = os.getenv("SUPPORTDESK_TASK_ID", "billing_refund_easy")
BENCHMARK = os.getenv("SUPPORTDESK_BENCHMARK", "supportdesk_env")
# Episode and sampling limits.
MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "300"))
# A clamped final score at or above this threshold logs success=true.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.1"))
# Submitted scores are clamped into [SUBMISSION_SCORE_MIN, SUBMISSION_SCORE_MAX].
SUBMISSION_SCORE_MIN = 0.01
SUBMISSION_SCORE_MAX = 0.99
# System message sent on every model call; the model must reply with exactly
# one JSON object whose keys mirror SupportDeskAction's fields.
SYSTEM_PROMPT = """You are a support operations agent solving one triage ticket.
Return exactly one JSON object with this schema:
{
"operation": "classify|request_info|draft_reply|add_internal_note|submit|wait",
"queue": string or null,
"priority": string or null,
"issue_type": string or null,
"status": string or null,
"resolution_code": string or null,
"requested_fields": [string],
"reply": string or null,
"internal_note": string or null
}
Use the policy snippets in the observation. Keep customer replies short, precise, and professional.
"""
@dataclass
class EpisodeResult:
    """Compact result used for final success logging."""

    # Raw grader total for the finished case (clamped later by _submission_score).
    final_score: float
    # Number of environment steps actually executed (may be < MAX_STEPS).
    steps_taken: int
    # Per-step reward deltas, in execution order.
    rewards: list[float]
def _build_client() -> OpenAI | None:
    """Build an OpenAI-compatible client, or return None when either the SDK
    is unavailable or no API key is configured (heuristic fallback mode)."""
    if OpenAI is not None and API_KEY:
        return OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    return None
def _extract_json(text: str) -> dict[str, Any]:
try:
return json.loads(text)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
if not match:
raise
return json.loads(match.group(0))
def _observation_prompt(observation: SupportDeskObservation) -> str:
    """Render *observation* as the user message sent to the chat model.

    The layout mirrors the schema described in SYSTEM_PROMPT: ticket facts,
    knowledge-base snippets, current case state, workflow hints, and history.
    """
    # One bullet per knowledge-base snippet.
    kb_lines = "\n".join(
        f"- {snippet.article_id}: {snippet.title}: {snippet.content}"
        for snippet in observation.knowledge_base
    )
    # One bullet per prior action (with signed reward delta); placeholder when empty.
    history_lines = "\n".join(
        f"- step {entry.step}: {entry.summary} ({entry.reward_delta:+.2f})"
        for entry in observation.action_history
    ) or "- none"
    return f"""Task: {observation.task_id} ({observation.difficulty})
Objective: {observation.objective}
Ticket subject: {observation.ticket.subject}
Ticket body: {observation.ticket.body}
Customer tier: {observation.ticket.customer_tier}
Region: {observation.ticket.region}
Affected users: {observation.ticket.affected_users}
SLA minutes remaining: {observation.current_sla_minutes_remaining}
Business impact: {observation.ticket.business_impact}
Secondary concerns: {observation.ticket.secondary_concerns}
Knowledge base:
{kb_lines}
Current case state:
- queue: {observation.case.queue}
- priority: {observation.case.priority}
- issue_type: {observation.case.issue_type}
- status: {observation.case.status}
- resolution_code: {observation.case.resolution_code}
- requested_fields: {observation.case.requested_fields}
- reply: {observation.case.reply}
- internal_note: {observation.case.internal_note}
- customer_follow_up: {observation.case.customer_follow_up.status}
Workflow stage: {observation.workflow_stage}
Required next actions: {observation.required_next_actions}
Risk flags: {observation.risk_flags}
Feedback: {observation.feedback}
Remaining steps: {observation.remaining_steps}
History:
{history_lines}
"""
def _model_action(client: OpenAI | None, observation: SupportDeskObservation) -> SupportDeskAction:
    """Pick the next action: ask the chat model, or use the heuristic policy.

    Any failure (API error, unparsable content, schema mismatch) degrades
    gracefully to the deterministic heuristic instead of aborting the episode.
    """
    if client is None:
        return heuristic_action(observation)
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _observation_prompt(observation)},
            ],
        )
        raw = response.choices[0].message.content or ""
        return SupportDeskAction(**_extract_json(raw))
    except Exception:
        # Best-effort fallback: never let a model/parse error kill the run.
        return heuristic_action(observation)
def _action_to_log_string(action: SupportDeskAction) -> str:
if hasattr(action, "model_dump_json"):
payload = action.model_dump(
exclude_none=True,
exclude_defaults=True,
exclude={"metadata"},
)
else:
payload = action.dict(
exclude_none=True,
exclude_defaults=True,
)
payload.pop("metadata", None)
return json.dumps(payload, separators=(",", ":"))
def _log_start(task: str) -> None:
    """Emit the [START] banner identifying task, benchmark env, and model."""
    banner = f"[START] task={task} env={BENCHMARK} model={MODEL_NAME}"
    print(banner, flush=True)
def _log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
error_value = error if error else "null"
print(
f"[STEP] step={step} action={action_str} reward={reward:.2f} "
f"done={str(done).lower()} error={error_value}",
flush=True,
)
def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
reward_text = ",".join(f"{reward:.2f}" for reward in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} "
f"score={score:.3f} rewards={reward_text}",
flush=True,
)
def _submission_score(score: float) -> float:
    """Clamp *score* into [SUBMISSION_SCORE_MIN, SUBMISSION_SCORE_MAX]."""
    # Same argument order as the original max(min(...)) composition.
    capped = min(SUBMISSION_SCORE_MAX, score)
    return max(SUBMISSION_SCORE_MIN, capped)
def _run_local_episode(task_id: str, client: OpenAI | None) -> EpisodeResult:
    """Run one episode against an in-process SupportDeskEnvironment.

    Logs a [STEP] line per action, grades the final case state, and always
    closes the environment. Step failures are re-raised as RuntimeError.
    """
    env = SupportDeskEnvironment(task_id=task_id)
    observation = env.reset()
    reward_trail: list[float] = []
    completed = 0
    try:
        for step_number in range(1, MAX_STEPS + 1):
            if observation.done:
                break
            chosen = _model_action(client, observation)
            chosen_str = _action_to_log_string(chosen)
            try:
                observation = env.step(chosen)
                step_reward = observation.reward or 0.0
                finished = observation.done
            except Exception as exc:
                raise RuntimeError(str(exc)) from exc
            _log_step(step_number, chosen_str, step_reward, finished, None)
            reward_trail.append(step_reward)
            completed = step_number
            if finished:
                break
        grade = grade_case(get_task(task_id), env.state.case)
        return EpisodeResult(
            final_score=grade.total_score,
            steps_taken=completed,
            rewards=reward_trail,
        )
    finally:
        env.close()
async def _run_docker_episode(task_id: str, client: OpenAI | None) -> EpisodeResult:
    """Run one episode against the environment packaged as a Docker image.

    Mirrors _run_local_episode but drives the async client API: reset/step
    return result objects carrying the observation, reward, and done flag.
    """
    env = await SupportDeskEnv.from_docker_image(
        LOCAL_IMAGE_NAME,
        env_vars={"SUPPORTDESK_TASK_ID": task_id},
    )
    reward_trail: list[float] = []
    completed = 0
    try:
        result = await env.reset()
        observation = result.observation
        for step_number in range(1, MAX_STEPS + 1):
            if result.done:
                break
            chosen = _model_action(client, observation)
            chosen_str = _action_to_log_string(chosen)
            try:
                result = await env.step(chosen)
                observation = result.observation
                step_reward = result.reward or 0.0
                finished = result.done
            except Exception as exc:
                raise RuntimeError(str(exc)) from exc
            _log_step(step_number, chosen_str, step_reward, finished, None)
            reward_trail.append(step_reward)
            completed = step_number
            if finished:
                break
        # Grade against the server-side case state after the episode ends.
        state = await env.state()
        grade = grade_case(get_task(task_id), state.case)
        return EpisodeResult(
            final_score=grade.total_score,
            steps_taken=completed,
            rewards=reward_trail,
        )
    finally:
        await env.close()
async def main() -> None:
    """Entry point: run one episode (docker or local) and always log [END].

    Defaults (failure, zero score/steps) are logged if the episode raises.
    """
    client = _build_client()
    success = False
    score = 0.0
    steps = 0
    reward_trail: list[float] = []
    _log_start(TASK_NAME)
    try:
        if LOCAL_IMAGE_NAME:
            episode = await _run_docker_episode(TASK_NAME, client)
        else:
            episode = _run_local_episode(TASK_NAME, client)
        score = _submission_score(episode.final_score)
        success = score >= SUCCESS_SCORE_THRESHOLD
        steps = episode.steps_taken
        reward_trail = episode.rewards
    finally:
        _log_end(success=success, steps=steps, score=score, rewards=reward_trail)
# Script entry point: run a single episode under asyncio.
if __name__ == "__main__":
    asyncio.run(main())