File size: 2,927 Bytes
3b452a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import annotations

import json
import os
import random
from pathlib import Path
from typing import Dict, List, Optional

from .scenario_types import EmailScenario


class DatasetLoader:
    """Load and cache ``EmailScenario`` records stored as JSON Lines files.

    Scenarios for a difficulty are read from ``<data_dir>/<difficulty>.jsonl``
    and, when present, ``<data_dir>/gold/<difficulty>.jsonl``. Parsed lists are
    cached per lower-cased difficulty for the lifetime of the loader.
    """

    def __init__(self, data_dir: Optional[Path] = None) -> None:
        """Choose the scenario directory.

        Args:
            data_dir: Directory holding the ``.jsonl`` files. When omitted (or
                falsy), the ``IDPI_SCENARIO_DIR`` environment variable is used
                if set; otherwise ``data/scenarios`` next to this module.
                A plain string path is accepted as well.
        """
        env_data_dir = os.getenv("IDPI_SCENARIO_DIR")
        if data_dir:
            base_dir = Path(data_dir)
        elif env_data_dir:
            base_dir = Path(env_data_dir)
        else:
            base_dir = Path(__file__).resolve().parent / "data" / "scenarios"
        self._data_dir = base_dir
        self._cache: Dict[str, List[EmailScenario]] = {}

    def load_scenarios(self, difficulty: str) -> "List[EmailScenario]":
        """Return every scenario for ``difficulty``, reading files on first use.

        The difficulty is lower-cased before lookup. When the same
        ``scenario_id`` appears more than once, only the first occurrence is
        kept; the base file is read before the ``gold`` file.

        Args:
            difficulty: Difficulty name; also the stem of the ``.jsonl`` file.

        Returns:
            The cached list of scenarios (the same list object on every call).

        Raises:
            FileNotFoundError: No scenario file exists for the difficulty.
            json.JSONDecodeError: A non-blank line is not valid JSON.
            ValueError: A record lacks a required field, or the files contain
                no scenarios at all.
        """
        normalized = difficulty.lower()
        # The cache never holds an empty list (an empty result raises below),
        # so a ``None`` check is enough to detect a miss.
        cached = self._cache.get(normalized)
        if cached is not None:
            return cached

        scenario_files = self._scenario_files(normalized)
        if not scenario_files:
            raise FileNotFoundError(f"Missing scenario file for difficulty '{normalized}' in {self._data_dir}")

        scenarios: List[EmailScenario] = []
        seen_ids = set()
        for scenario_path in scenario_files:
            for scenario in self._read_scenario_file(scenario_path):
                if scenario.scenario_id in seen_ids:
                    continue
                seen_ids.add(scenario.scenario_id)
                scenarios.append(scenario)

        if not scenarios:
            raise ValueError(f"No scenarios found for difficulty '{normalized}' in {self._data_dir}")

        self._cache[normalized] = scenarios
        return scenarios

    def get_scenario(self, difficulty: str, seed: Optional[int] = None) -> "EmailScenario":
        """Pick one scenario for ``difficulty`` at random.

        Args:
            difficulty: Difficulty name passed to :meth:`load_scenarios`.
            seed: When given, a dedicated ``random.Random(seed)`` makes the
                choice reproducible; otherwise the module-level ``random``
                generator is used.
        """
        scenarios = self.load_scenarios(difficulty)
        chooser = random.Random(seed) if seed is not None else random
        return chooser.choice(scenarios)

    def get_scenario_by_id(self, difficulty: str, scenario_id: str) -> "EmailScenario":
        """Return the scenario of ``difficulty`` whose id equals ``scenario_id``.

        Raises:
            KeyError: No loaded scenario has that id.
        """
        for scenario in self.load_scenarios(difficulty):
            if scenario.scenario_id == scenario_id:
                return scenario
        raise KeyError(f"Unknown scenario_id '{scenario_id}' for difficulty '{difficulty}'")

    def _scenario_files(self, difficulty: str) -> List[Path]:
        """Return the existing scenario files for ``difficulty``, base file first."""
        candidates = (
            self._data_dir / f"{difficulty}.jsonl",
            self._data_dir / "gold" / f"{difficulty}.jsonl",
        )
        # ``is_file`` also skips a directory that happens to carry the expected
        # name, which would otherwise fail later when opened.
        return [path for path in candidates if path.is_file()]

    @staticmethod
    def _read_scenario_file(scenario_path: Path) -> "List[EmailScenario]":
        """Parse one JSON Lines file into scenarios, skipping blank lines."""
        parsed: List[EmailScenario] = []
        # "utf-8-sig" strips a leading byte-order mark if the file has one.
        with scenario_path.open("r", encoding="utf-8-sig") as handle:
            for line_number, raw_line in enumerate(handle, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    payload = json.loads(line)
                except json.JSONDecodeError as exc:
                    # Keep the exception type but say which file and which
                    # line of that file is malformed.
                    raise json.JSONDecodeError(
                        f"Scenario file {scenario_path} has invalid JSON on line {line_number}: {exc.msg}",
                        exc.doc,
                        exc.pos,
                    ) from exc
                try:
                    parsed.append(EmailScenario.from_dict(payload))
                except KeyError as exc:
                    raise ValueError(
                        f"Scenario file {scenario_path} is missing field {exc} on line {line_number}"
                    ) from exc
        return parsed