| |
| import argparse |
| import os |
| import re |
| import subprocess |
| from datetime import datetime, timedelta |
|
|
| |
| |
# Bash template for the generated sbatch script.  Single-brace placeholders
# ({job_name}, {qos}, {gpus}, ...) are filled in via str.format() before the
# script is written; doubled braces ({{...}}) survive formatting and become
# literal shell ${VAR} expansions at job runtime.
TEMPLATE = """#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --account=ojz@h100
#SBATCH --constraint=h100
#SBATCH --qos={qos}
#SBATCH --time={time}
#SBATCH --nodes=1
#SBATCH --ntasks-per-node={gpus}
#SBATCH --gres=gpu:{gpus}
#SBATCH --cpus-per-task=24
#SBATCH --hint=nomultithread
#SBATCH --output=logs/slurm/%x-%j.log
#SBATCH --error=logs/slurm/%x-%j.log

set -euxo pipefail

export MPLBACKEND=Agg

if ! command -v module >/dev/null 2>&1; then
source /etc/profile.d/modules.sh || true
fi

module load arch/h100
module load {ffmpeg_module}

FFMPEG_BIN=$(command -v ffmpeg || true)
if [ -n "$FFMPEG_BIN" ]; then
FFMPEG_ROOT=$(dirname "$(dirname "$FFMPEG_BIN")")
export LD_LIBRARY_PATH="${{FFMPEG_ROOT}}/lib:${{LD_LIBRARY_PATH}}"
fi

if [ -n "${{EBROOTFFMPEG:-}}" ]; then
export LD_LIBRARY_PATH="${{EBROOTFFMPEG}}/lib:${{LD_LIBRARY_PATH}}"
fi

cd {workdir}

export PYTHONUNBUFFERED=1
export HYDRA_FULL_ERROR=1
export TMPDIR=$SCRATCH
export TEMP=$SCRATCH
export TMP=$SCRATCH
export PROJECT_ROOT={workdir}

# Ensure log directory exists
mkdir -p logs/slurm

source .venv/bin/activate

# Configuration Info
# Experiment: {experiment}
# GPUs: {gpus}
# Strategy: {strategy}
# WandB Name: {wandb_name}
# Trainer Max Time: {max_time}
# FFmpeg module: {ffmpeg_module}

echo "Starting job {job_name} on $(hostname)"
echo "Experiment: {experiment}"
echo "FFmpeg binary: $(command -v ffmpeg || echo 'not found')"
ffmpeg -version | head -n 1

srun .venv/bin/python -u -O src/train.py \\
experiment={experiment} \\
++trainer.devices={gpus} \\
++trainer.strategy={strategy} \\
++trainer.max_time="{max_time}" \\
{extra_args}
"""
|
|
|
|
def parse_slurm_time(time_str):
    """Parse a Slurm time string into a timedelta.

    Slurm accepts: "MM", "MM:SS", "HH:MM:SS", "D-HH", "D-HH:MM",
    "D-HH:MM:SS".  The field semantics shift once a day component is
    present: the first field after the dash is HOURS, not minutes.
    (The previous implementation reused the minutes-first branches for
    the dashed forms, so "2-12" parsed as 2 days 12 *minutes*.)

    Raises:
        ValueError: if the string has more than three colon-separated
            fields, or any field is not an integer.
    """
    days = 0
    has_days = "-" in time_str
    if has_days:
        days_str, time_str = time_str.split("-", 1)
        days = int(days_str)

    parts = list(map(int, time_str.split(":")))
    if not 1 <= len(parts) <= 3:
        raise ValueError(f"Invalid time format: {time_str}")

    hours = minutes = seconds = 0
    if has_days:
        # Dashed forms: D-HH, D-HH:MM, D-HH:MM:SS (hours come first).
        if len(parts) == 1:
            hours = parts[0]
        elif len(parts) == 2:
            hours, minutes = parts
        else:
            hours, minutes, seconds = parts
    else:
        # Plain forms: MM, MM:SS, HH:MM:SS.
        if len(parts) == 1:
            minutes = parts[0]
        elif len(parts) == 2:
            minutes, seconds = parts
        else:
            hours, minutes, seconds = parts

    return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
|
|
|
|
def format_timedelta(td):
    """Render a timedelta as a zero-padded DD:HH:MM:SS string for Lightning."""
    remaining = int(td.total_seconds())
    remaining, seconds = divmod(remaining, 60)
    remaining, minutes = divmod(remaining, 60)
    days, hours = divmod(remaining, 24)
    return f"{days:02}:{hours:02}:{minutes:02}:{seconds:02}"
|
|
|
|
def parse_config_value(content, pattern):
    """Search *content* for *pattern*; return group 1 stripped, or None if absent."""
    found = re.search(pattern, content)
    if found is None:
        return None
    return found.group(1).strip()
|
|
|
|
def select_qos(time_limit: timedelta) -> str:
    """Pick the H100 QoS tier whose time window covers *time_limit*.

    Tiers (inclusive upper bounds): dev <= 2h, t3 <= 20h, t4 <= 100h.

    Raises:
        ValueError: if the requested time exceeds 100 hours.
    """
    tiers = (
        (timedelta(hours=2), "qos_gpu_h100-dev"),
        (timedelta(hours=20), "qos_gpu_h100-t3"),
        (timedelta(hours=100), "qos_gpu_h100-t4"),
    )
    for ceiling, qos in tiers:
        if time_limit <= ceiling:
            return qos
    raise ValueError(
        "Requested time exceeds maximum supported QoS window (100h). "
        "Please request 100:00:00 or less."
    )
|
|
|
|
def format_steps(steps_str):
    """Abbreviate a digit-string step count: '400000' -> '400k', '2000000' -> '2m'.

    Anything empty or non-numeric is passed through unchanged.
    """
    if not steps_str or not steps_str.isdigit():
        return steps_str

    count = int(steps_str)
    for threshold, unit in ((1000000, "m"), (1000, "k")):
        if count >= threshold:
            return f"{count // threshold}{unit}"
    return str(count)
|
|
|
|
def generate_wandb_name(config_path, num_gpus, suffix=None):
    """Derive a WandB run name from an experiment YAML config.

    Joins (model, data, abbreviated max_steps, "<batch>x<gpus>bs", suffix)
    with '-'.  Falls back to "experiment" when the config file is missing
    or none of the fields can be extracted.
    """
    try:
        with open(config_path, "r") as f:
            content = f.read()
    except FileNotFoundError:
        print(
            f"Warning: Config file not found at {config_path}. Cannot auto-generate name."
        )
        return "experiment"

    model = parse_config_value(content, r"override /model:\s*(\S+)")
    dataset = parse_config_value(content, r"override /data:\s*(\S+)")
    batch_size = parse_config_value(content, r"batch_size:\s*(\d+)")
    max_steps = parse_config_value(content, r"max_steps:\s*(\d+)")

    pieces = [part for part in (model, dataset) if part]
    if max_steps:
        pieces.append(format_steps(max_steps))
    if batch_size:
        pieces.append(f"{batch_size}x{num_gpus}bs")
    if suffix:
        pieces.append(suffix)

    return "-".join(pieces) if pieces else "experiment"
|
|
|
|
def main():
    """Parse CLI args, render the Slurm script from TEMPLATE, and sbatch it.

    Unrecognized CLI args are forwarded to Hydra (bare key=value pairs gain
    a '++' force-override prefix).  With --dry-run the rendered script is
    printed instead of submitted.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate and submit Slurm jobs for Audio Embeddings. "
            "WandB run names are generated from the experiment config "
            "(model, data, max_steps, batch_size x GPUs), plus optional suffix."
        )
    )
    parser.add_argument(
        "experiment",
        type=str,
        help="Experiment config path (e.g., audio_jepa/baseline)",
    )
    parser.add_argument(
        "--gpus", type=int, default=1, help="Number of GPUs to request (default: 1)"
    )
    parser.add_argument(
        "--time",
        type=str,
        default="20:00:00",
        help="Time limit (HH:MM:SS) (default: 20:00:00)",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        help=(
            "Optional suffix for WandB run name. "
            "Base name is derived from the experiment config: "
            "model + data + max_steps (k/m) + batch_size x GPUs."
        ),
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the generated script without submitting",
    )
    parser.add_argument(
        "--ffmpeg-module",
        type=str,
        default="ffmpeg/6.1.1",
        help=(
            "FFmpeg environment module to load in the job script "
            "(default: ffmpeg/6.1.1)."
        ),
    )

    args, unknown = parser.parse_known_args()

    # Multi-GPU runs use DDP; sync_batchnorm keeps batch-norm statistics
    # consistent across ranks.  Single GPU lets Lightning pick ("auto").
    if args.gpus > 1:
        strategy = "ddp"
        extra_args_list = ["++trainer.sync_batchnorm=True"]
    else:
        strategy = "auto"
        extra_args_list = []

    # Forward unknown args to Hydra.  Bare key=value pairs get '++' so they
    # force-override config entries; args already prefixed with '+'/'~' are
    # passed through untouched.
    for arg in unknown:
        if arg.startswith(("+", "~")):
            extra_args_list.append(arg)
        elif "=" in arg:
            extra_args_list.append("++" + arg)
        else:
            extra_args_list.append(arg)

    extra_args = " ".join(extra_args_list)

    workdir = os.path.abspath(os.getcwd())

    config_path = os.path.join(
        workdir, "configs", "experiment", f"{args.experiment}.yaml"
    )
    wandb_name = generate_wandb_name(config_path, args.gpus, args.suffix)

    # The Slurm job name mirrors the WandB run name so logs are findable.
    job_name = wandb_name

    try:
        slurm_time_td = parse_slurm_time(args.time)
        qos = select_qos(slurm_time_td)
        # Stop training 10 minutes before the Slurm wall-clock limit so the
        # trainer can checkpoint before the job is killed.
        buffer_time = timedelta(minutes=10)

        if slurm_time_td > buffer_time:
            max_time_td = slurm_time_td - buffer_time
        else:
            print(
                f"Warning: Requested time {args.time} is less than buffer (10m). Using full time."
            )
            max_time_td = slurm_time_td

        max_time_str = format_timedelta(max_time_td)
    except Exception as e:
        raise ValueError(f"Invalid --time value '{args.time}': {e}") from e

    script_content = TEMPLATE.format(
        job_name=job_name,
        qos=qos,
        time=args.time,
        gpus=args.gpus,
        workdir=workdir,
        experiment=args.experiment,
        strategy=strategy,
        wandb_name=wandb_name,
        max_time=max_time_str,
        ffmpeg_module=args.ffmpeg_module,
        extra_args=extra_args,
    )

    if args.dry_run:
        print("--- Dry Run: Generated Slurm Script ---")
        print(script_content)
        print("---------------------------------------")
        return

    script_dir = os.path.join(workdir, "slurm_scripts", ".generated")
    os.makedirs(script_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(script_dir, f"submit_{job_name}_{timestamp}.slurm")

    with open(filename, "w") as f:
        f.write(script_content)

    # Fix: previously printed the literal placeholder "(unknown)" instead of
    # the generated script path.
    print(f"Generated script: {filename}")

    try:
        result = subprocess.run(
            ["sbatch", filename], check=True, capture_output=True, text=True
        )
        print(f"Submission successful: {result.stdout.strip()}")
    except subprocess.CalledProcessError as e:
        # Best-effort: report the failure but don't crash, so the generated
        # script remains on disk for manual submission.
        print("Error: Submission failed!")
        print(f"Stderr: {e.stderr}")
|
|
|
|
# Run only when executed as a script; importing the module has no side effects.
if __name__ == "__main__":
    main()
|
|