File size: 2,927 Bytes
3b452a0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | from __future__ import annotations
import json
import os
import random
from pathlib import Path
from typing import Dict, List, Optional
from .scenario_types import EmailScenario
class DatasetLoader:
    """Load and cache ``EmailScenario`` records stored as JSON Lines files.

    Scenarios for a difficulty are read from ``<data_dir>/<difficulty>.jsonl``
    and, when present, ``<data_dir>/gold/<difficulty>.jsonl``. Parsed results
    are cached per (lower-cased) difficulty for the lifetime of the loader.
    """

    def __init__(self, data_dir: Optional[Path] = None) -> None:
        """Choose the scenario directory.

        Precedence: the explicit ``data_dir`` argument, then the
        ``IDPI_SCENARIO_DIR`` environment variable, then the
        ``data/scenarios`` directory next to this module.
        """
        env_data_dir = os.getenv("IDPI_SCENARIO_DIR")
        self._data_dir = data_dir or (
            Path(env_data_dir)
            if env_data_dir
            else Path(__file__).resolve().parent / "data" / "scenarios"
        )
        self._cache: Dict[str, List[EmailScenario]] = {}

    def load_scenarios(self, difficulty: str) -> List[EmailScenario]:
        """Return every scenario for ``difficulty``, parsing files on first use.

        The difficulty is lower-cased before lookup. Blank lines are skipped,
        and when a ``scenario_id`` occurs more than once only the first
        occurrence is kept (the base file is read before the gold file).
        Repeated calls return the same cached list object.

        Raises:
            FileNotFoundError: no scenario file exists for the difficulty.
            json.JSONDecodeError: a non-blank line is not valid JSON; the
                message names the file and the line number within it.
            ValueError: a record is missing a required field, or the files
                contain no scenarios at all.
        """
        normalized = difficulty.lower()
        if normalized in self._cache:
            return self._cache[normalized]

        scenario_files = self._scenario_files(normalized)
        if not scenario_files:
            raise FileNotFoundError(
                f"Missing scenario file for difficulty '{normalized}' in {self._data_dir}"
            )

        scenarios: List[EmailScenario] = []
        seen_ids = set()
        for scenario_path in scenario_files:
            # "utf-8-sig" transparently drops a leading BOM if one is present.
            with scenario_path.open("r", encoding="utf-8-sig") as handle:
                for line_number, raw_line in enumerate(handle, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        payload = json.loads(line)
                    except json.JSONDecodeError as exc:
                        # The decoder only knows positions inside this single
                        # line, so add the file and its line number. The same
                        # exception type is re-raised so existing handlers
                        # keep working.
                        raise json.JSONDecodeError(
                            f"Scenario file {scenario_path} has invalid JSON "
                            f"on line {line_number}: {exc.msg}",
                            exc.doc,
                            exc.pos,
                        ) from exc
                    try:
                        scenario = EmailScenario.from_dict(payload)
                    except KeyError as exc:
                        raise ValueError(
                            f"Scenario file {scenario_path} is missing field {exc} on line {line_number}"
                        ) from exc
                    if scenario.scenario_id in seen_ids:
                        continue
                    seen_ids.add(scenario.scenario_id)
                    scenarios.append(scenario)

        if not scenarios:
            raise ValueError(
                f"No scenarios found for difficulty '{normalized}' in {self._data_dir}"
            )
        self._cache[normalized] = scenarios
        return scenarios

    def get_scenario(self, difficulty: str, seed: Optional[int] = None) -> EmailScenario:
        """Pick one scenario for ``difficulty`` at random.

        When ``seed`` is given a dedicated ``random.Random(seed)`` makes the
        choice, so the same seed selects the same scenario from an unchanged
        list; otherwise the module-level ``random`` generator is used.
        """
        scenarios = self.load_scenarios(difficulty)
        chooser = random.Random(seed) if seed is not None else random
        return chooser.choice(scenarios)

    def get_scenario_by_id(self, difficulty: str, scenario_id: str) -> EmailScenario:
        """Return the scenario whose ``scenario_id`` equals ``scenario_id``.

        Raises:
            KeyError: no scenario with that id exists for the difficulty.
        """
        for scenario in self.load_scenarios(difficulty):
            if scenario.scenario_id == scenario_id:
                return scenario
        raise KeyError(f"Unknown scenario_id '{scenario_id}' for difficulty '{difficulty}'")

    def _scenario_files(self, difficulty: str) -> List[Path]:
        """List the existing scenario files for ``difficulty``, base file first."""
        candidates = (
            self._data_dir / f"{difficulty}.jsonl",
            self._data_dir / "gold" / f"{difficulty}.jsonl",
        )
        return [path for path in candidates if path.exists()]
|