| """ |
| Utilities for checkpointing evaluation-related state (e.g. evaluation results). |
| |
| Evaluation results are saved as a step-specific JSON file in the run's evaluation results directory. |
| """ |
|
|
| import json |
| import os |
| from typing import Any, Dict |
|
|
| from huggingface_hub import upload_folder |
| from lightning.fabric import Fabric |
| from lightning.fabric.utilities.rank_zero import rank_zero_only |
|
|
| from src.config import CheckpointingConfig |
| from src.training.utils.io import use_backoff |
|
|
|
|
@rank_zero_only
@use_backoff()
def save_evaluation_results(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: int,
    fabric: Fabric,
    evaluation_results: Dict[str, Any],
) -> None:
    """Save evaluation results to disk and optionally to HuggingFace Hub.

    The evaluation results are saved in the following directory structure:
    {checkpointing_config.runs_dir}/
    └── {checkpointing_config.run_name}/
        └── {checkpointing_config.evaluation.eval_results_dir}/
            └── step_{checkpoint_step}.json

    The JSON is first written to a temporary sibling file and then moved into
    place with ``os.replace``. An interruption mid-write therefore never leaves
    a truncated ``step_{checkpoint_step}.json`` in the directory. Without this,
    a later call would upload the truncated file to the Hub along with the rest
    of the folder.

    If ``checkpointing_config.save_to_hf`` is set, the entire evaluation
    results directory is uploaded to ``checkpointing_config.hf_checkpoint.repo_id``
    under ``path_in_repo=checkpointing_config.evaluation.eval_results_dir``.
    The upload targets the revision named ``checkpointing_config.run_name``.
    The value of the ``HF_TOKEN`` environment variable (``None`` if unset) is
    passed as the token.

    NOTE: this function is only called on rank 0 to avoid conflicts; assumes that the
    evaluation results are gathered on rank 0. ``use_backoff`` presumably retries the
    whole call on failure (TODO confirm). Rewriting the file on a retry is harmless.

    Args:
        checkpointing_config: Configuration object containing checkpoint settings
        checkpoint_step: Current training checkpoint step (i.e. number of learning steps taken)
        fabric: Lightning Fabric instance (not used by this function)
        evaluation_results: Dictionary containing evaluation metrics; must be
            serializable by ``json.dump``
    """
    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
    eval_results_dir = os.path.join(
        run_dir, checkpointing_config.evaluation.eval_results_dir
    )
    os.makedirs(eval_results_dir, exist_ok=True)

    curr_eval_results_path = os.path.join(
        eval_results_dir, f"step_{checkpoint_step}.json"
    )

    # Write to a temp file in the same directory, so it is on the same
    # filesystem and os.replace is an atomic rename on POSIX. Only swap it in
    # once the JSON has been fully written.
    tmp_eval_results_path = f"{curr_eval_results_path}.tmp"
    try:
        with open(tmp_eval_results_path, "w", encoding="utf-8") as f:
            json.dump(evaluation_results, f)
        os.replace(tmp_eval_results_path, curr_eval_results_path)
    except BaseException:
        # Don't leave a half-written temp file in the directory that gets uploaded.
        try:
            os.remove(tmp_eval_results_path)
        except FileNotFoundError:
            pass
        raise

    if checkpointing_config.save_to_hf:
        upload_folder(
            folder_path=eval_results_dir,
            path_in_repo=checkpointing_config.evaluation.eval_results_dir,
            repo_id=checkpointing_config.hf_checkpoint.repo_id,
            commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}",
            revision=checkpointing_config.run_name,
            token=os.getenv("HF_TOKEN"),
        )
|
|