File size: 1,644 Bytes
5dd1bb4
 
9e64e71
 
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
 
 
5dd1bb4
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Training utilities for GRPO-based SQLEnv experiments."""

import importlib

from .config import GRPOConfig, apply_device_overrides, find_project_root

# Public API of this package. The first three names are imported eagerly from
# ``.config`` at the top of this module; every other name is a key of
# ``_LAZY_MAP`` and is resolved on demand by the module-level ``__getattr__``.
__all__ = [
    "GRPOConfig",
    "apply_device_overrides",
    "find_project_root",
    "build_trainer",
    "filter_questions_by_difficulty",
    "format_observation",
    "format_oom_guidance",
    "get_system_prompt",
    "load_model_and_tokenizer",
    "load_question_prompts",
    "run_training_with_metrics",
    "sample_random_baseline",
    "reward_correctness",
    "reward_progress",
    "reward_operational",
    "LiveVisualizationCallback",
    "SQLEnvTRL",
    "sql_env_reward_func",
]

# Maps each lazily exported attribute name to the submodule it is fetched from.
# Values are relative module paths (note the leading dot); ``__getattr__``
# passes them to ``importlib.import_module`` with this package as the anchor.
_LAZY_MAP = {
    "filter_questions_by_difficulty": ".data_loading",
    "load_model_and_tokenizer": ".data_loading",
    "load_question_prompts": ".data_loading",
    "build_trainer": ".notebook_pipeline",
    "format_oom_guidance": ".notebook_pipeline",
    "run_training_with_metrics": ".notebook_pipeline",
    "sample_random_baseline": ".notebook_pipeline",
    "format_observation": ".prompts",
    "get_system_prompt": ".prompts",
    "reward_correctness": ".rewards",
    "reward_operational": ".rewards",
    "reward_progress": ".rewards",
    "LiveVisualizationCallback": ".visualization",
    "SQLEnvTRL": ".trl_adapter",
    "sql_env_reward_func": ".trl_adapter",
}


def __getattr__(name: str):
    """Lazy-load heavy modules on first access (PEP 562 module hook).

    Python invokes this hook only when ``name`` is not found in the module
    namespace. Names listed in ``_LAZY_MAP`` are imported from their mapped
    submodule when first requested, and the resolved object is then stored
    in the module namespace so later lookups bypass this hook entirely.

    Args:
        name: Attribute name being looked up on this package.

    Returns:
        The attribute called ``name`` taken from the submodule that
        ``_LAZY_MAP`` maps it to.

    Raises:
        AttributeError: If ``name`` is not a key of ``_LAZY_MAP``, or the
            mapped submodule has no attribute ``name``.
        ImportError: Propagated unchanged if the submodule fails to import.
    """
    try:
        module_name = _LAZY_MAP[name]
    except KeyError:
        # ``from None`` hides the internal KeyError so the traceback looks
        # like an ordinary missing-attribute error.
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None
    # Map entries are relative (".rewards"), so anchor them to this package.
    mod = importlib.import_module(module_name, __name__)
    value = getattr(mod, name)
    # Cache the result: subsequent accesses succeed through normal module
    # attribute lookup and never re-enter this hook.
    globals()[name] = value
    return value