File size: 1,720 Bytes
ab97cf6
 
 
4888b97
ab97cf6
4888b97
ab97cf6
4888b97
 
 
ab97cf6
4888b97
 
 
 
ab97cf6
4888b97
ab97cf6
 
4ddf12d
 
ab97cf6
 
4888b97
4ddf12d
 
 
 
 
 
 
c372dd3
4888b97
 
c372dd3
4ddf12d
 
 
 
 
 
ab97cf6
4888b97
 
 
c372dd3
 
4888b97
 
 
 
 
 
 
ab97cf6
 
 
4888b97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import annotations

import os
import sys
from pathlib import Path
from typing import Dict

# Put the directory containing this file at the front of the import search
# path, so the `baseline` package imported below is looked up there first.
REPO_ROOT = Path(__file__).resolve().parent
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)

try:
    from openai import OpenAI
except ModuleNotFoundError:  # pragma: no cover - dependency-gated in local shells
    OpenAI = None

from baseline.run_baseline import main as baseline_main


# Fallback endpoint used when the API_BASE_URL environment variable is not set.
# NOTE(review): the URL looks like Google's OpenAI-compatible Generative
# Language API endpoint — confirm against the provider documentation.
DEFAULT_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
# Fallback model identifier used when the MODEL_NAME environment variable is not set.
DEFAULT_MODEL_NAME = "gemini-2.5-flash"


def validate_submission_env() -> Dict[str, str | None]:
    """Validate the environment required for submission inference.

    Reads ``API_BASE_URL``, ``MODEL_NAME``, ``HF_TOKEN`` and
    ``LOCAL_IMAGE_NAME`` from the process environment. ``API_BASE_URL`` and
    ``MODEL_NAME`` fall back to ``DEFAULT_API_BASE_URL`` and
    ``DEFAULT_MODEL_NAME`` when unset or empty. ``LOCAL_IMAGE_NAME`` is
    optional and is reported as ``None`` when unset.

    Returns:
        A mapping with the keys ``API_BASE_URL``, ``MODEL_NAME``,
        ``HF_TOKEN`` and ``LOCAL_IMAGE_NAME``. Only ``LOCAL_IMAGE_NAME`` may
        map to ``None``.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is unset or empty, or if the ``openai``
            package could not be imported.
    """
    # Use ``or`` rather than a getenv default: a variable exported as an empty
    # string must fall back too, matching how an empty HF_TOKEN is rejected.
    api_base_url = os.getenv("API_BASE_URL") or DEFAULT_API_BASE_URL
    model_name = os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME
    hf_token = os.getenv("HF_TOKEN")
    local_image_name = os.getenv("LOCAL_IMAGE_NAME")

    if not hf_token:
        raise RuntimeError("Missing required inference environment variable: HF_TOKEN.")

    # ``OpenAI`` is set to None by the guarded import when the package is absent.
    if OpenAI is None:
        raise RuntimeError("The openai package is required for submission inference.")

    config: Dict[str, str | None] = {
        "API_BASE_URL": api_base_url,
        "MODEL_NAME": model_name,
        "HF_TOKEN": hf_token,
        "LOCAL_IMAGE_NAME": local_image_name,
    }

    # Fail fast on invalid client wiring before the benchmark starts. The
    # client is built only for this check and then discarded.
    OpenAI(base_url=api_base_url, api_key=hf_token)
    return config


def main() -> None:
    """Validate the environment, configure the baseline run, and start it.

    ``BASELINE_AGENT`` and ``BASELINE_OUTPUT_MODE`` are always overwritten.
    ``EPISODES_PER_DIFFICULTY`` and ``BASELINE_RESULTS_PATH`` are only set
    when the environment does not already define them. Any ``RuntimeError``
    raised by ``validate_submission_env`` propagates before the baseline runs.
    """
    validate_submission_env()

    # Settings this entry point forces, regardless of what is already set.
    os.environ.update(
        {
            "BASELINE_AGENT": "openai",
            "BASELINE_OUTPUT_MODE": "submission",
        }
    )

    # Settings that apply only when the caller has not provided a value.
    overridable_defaults = (
        ("EPISODES_PER_DIFFICULTY", "1"),
        ("BASELINE_RESULTS_PATH", "baseline/submission_results.json"),
    )
    for variable, fallback in overridable_defaults:
        os.environ.setdefault(variable, fallback)

    baseline_main()


# Run the submission flow only when this file is executed as a script;
# importing the module does not call main().
if __name__ == "__main__":
    main()