"""Training utilities for GRPO-based SQLEnv experiments."""
import importlib
from .config import GRPOConfig, apply_device_overrides, find_project_root
# Public API of this package. The first three names are imported eagerly from
# ``.config`` at the top of this file; every other name is a key of
# ``_LAZY_MAP`` and is resolved on attribute access by the module-level
# ``__getattr__``.
__all__ = [
    # Imported eagerly from .config
    "GRPOConfig",
    "apply_device_overrides",
    "find_project_root",
    # Resolved lazily through _LAZY_MAP
    "build_trainer",
    "filter_questions_by_difficulty",
    "format_observation",
    "format_oom_guidance",
    "get_system_prompt",
    "load_model_and_tokenizer",
    "load_question_prompts",
    "run_training_with_metrics",
    "sample_random_baseline",
    "reward_correctness",
    "reward_progress",
    "reward_operational",
    "LiveVisualizationCallback",
    "SQLEnvTRL",
    "sql_env_reward_func",
]
# Maps each lazily exported name to the relative submodule that defines it.
# Entries are grouped per submodule; the module-level ``__getattr__`` looks
# names up here and imports the submodule on demand.
_LAZY_MAP = {
    **dict.fromkeys(
        (
            "filter_questions_by_difficulty",
            "load_model_and_tokenizer",
            "load_question_prompts",
        ),
        ".data_loading",
    ),
    **dict.fromkeys(
        (
            "build_trainer",
            "format_oom_guidance",
            "run_training_with_metrics",
            "sample_random_baseline",
        ),
        ".notebook_pipeline",
    ),
    **dict.fromkeys(
        (
            "format_observation",
            "get_system_prompt",
        ),
        ".prompts",
    ),
    **dict.fromkeys(
        (
            "reward_correctness",
            "reward_operational",
            "reward_progress",
        ),
        ".rewards",
    ),
    **dict.fromkeys(
        ("LiveVisualizationCallback",),
        ".visualization",
    ),
    **dict.fromkeys(
        (
            "SQLEnvTRL",
            "sql_env_reward_func",
        ),
        ".trl_adapter",
    ),
}
def __getattr__(name: str) -> object:
    """Resolve a lazily exported attribute on first access.

    This is the module-level ``__getattr__`` hook (PEP 562): Python calls it
    only when normal attribute lookup on this module fails. Names listed in
    ``_LAZY_MAP`` are fetched from their submodule on demand instead of
    being imported by this file at package import time.

    Args:
        name: Attribute being looked up on this package.

    Returns:
        The attribute called ``name`` taken from the submodule that
        ``_LAZY_MAP`` associates with it.

    Raises:
        AttributeError: If ``name`` is not a key of ``_LAZY_MAP``, or if the
            mapped submodule does not define ``name``.
    """
    try:
        module_name = _LAZY_MAP[name]
    except KeyError:
        # ``from None`` keeps the traceback identical to a plain missing
        # attribute instead of chaining the internal KeyError.
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None

    # Entries in _LAZY_MAP are relative (".submodule"), so anchor the import
    # at this package.
    mod = importlib.import_module(module_name, __name__)
    value = getattr(mod, name)

    # Store the result in the module namespace: later lookups succeed through
    # normal attribute access and never reach this hook again.
    globals()[name] = value
    return value
|