ltuncay's picture
Submission to the Interspeech 2026 Audio Encoder Capability Challenge
eca55dc verified
#!/usr/bin/env python3
import argparse
import os
import re
import subprocess
from datetime import datetime, timedelta
# Slurm Script Template
# Adapt the directives (account, constraint, QoS names) to your cluster configuration.
# The template is rendered with str.format(), so braces meant literally for bash
# (e.g. ${VAR}) are written doubled.
TEMPLATE = """#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --account=ojz@h100
#SBATCH --constraint=h100
#SBATCH --qos={qos}
#SBATCH --time={time}
#SBATCH --nodes=1
#SBATCH --ntasks-per-node={gpus}
#SBATCH --gres=gpu:{gpus}
#SBATCH --cpus-per-task=24
#SBATCH --hint=nomultithread
#SBATCH --output=logs/slurm/%x-%j.log
#SBATCH --error=logs/slurm/%x-%j.log
set -euxo pipefail
export MPLBACKEND=Agg
if ! command -v module >/dev/null 2>&1; then
    source /etc/profile.d/modules.sh || true
fi
module load arch/h100
module load {ffmpeg_module}
# Prepend FFmpeg's lib dir. The ':+' expansion only appends the old value when it
# is non-empty: this avoids an "unbound variable" abort under 'set -u' and avoids
# an empty path entry (which the loader treats as the current directory).
FFMPEG_BIN=$(command -v ffmpeg || true)
if [ -n "$FFMPEG_BIN" ]; then
    FFMPEG_ROOT=$(dirname "$(dirname "$FFMPEG_BIN")")
    export LD_LIBRARY_PATH="${{FFMPEG_ROOT}}/lib${{LD_LIBRARY_PATH:+:${{LD_LIBRARY_PATH}}}}"
fi
if [ -n "${{EBROOTFFMPEG:-}}" ]; then
    export LD_LIBRARY_PATH="${{EBROOTFFMPEG}}/lib${{LD_LIBRARY_PATH:+:${{LD_LIBRARY_PATH}}}}"
fi
cd {workdir}
export PYTHONUNBUFFERED=1
export HYDRA_FULL_ERROR=1
export TMPDIR=$SCRATCH
export TEMP=$SCRATCH
export TMP=$SCRATCH
export PROJECT_ROOT={workdir}
# Ensure log directory exists
mkdir -p logs/slurm
source .venv/bin/activate
# Configuration Info
# Experiment: {experiment}
# GPUs: {gpus}
# Strategy: {strategy}
# WandB Name: {wandb_name}
# Trainer Max Time: {max_time}
# FFmpeg module: {ffmpeg_module}
echo "Starting job {job_name} on $(hostname)"
echo "Experiment: {experiment}"
echo "FFmpeg binary: $(command -v ffmpeg || echo 'not found')"
ffmpeg -version | head -n 1
srun .venv/bin/python -u -O src/train.py \\
    experiment={experiment} \\
    ++trainer.devices={gpus} \\
    ++trainer.strategy={strategy} \\
    ++trainer.max_time="{max_time}" \\
    ++logger.wandb.name="{wandb_name}" \\
    {extra_args}
"""
def parse_slurm_time(time_str):
    """Parse a Slurm ``--time`` string into a ``timedelta``.

    Accepted formats: "MM", "MM:SS", "HH:MM:SS", "D-HH", "D-HH:MM", "D-HH:MM:SS".

    The meaning of the colon-separated fields depends on whether a ``D-`` day
    prefix is present. Without it, one field is minutes and two fields are
    minutes:seconds. With it, the fields are HH[:MM[:SS]].

    Args:
        time_str: Time limit string, e.g. "20:00:00" or "1-12".

    Returns:
        The equivalent ``timedelta``.

    Raises:
        ValueError: If the string is empty, has more than three clock fields,
            or contains non-integer or negative components.
    """
    text = time_str.strip()
    days_part, has_days, clock_part = text.partition("-")
    if not has_days:
        clock_part = text
    try:
        days = int(days_part) if has_days else 0
        fields = [int(field) for field in clock_part.split(":")]
    except ValueError:
        raise ValueError(f"Invalid time format: {time_str}") from None
    if len(fields) > 3 or days < 0 or any(value < 0 for value in fields):
        raise ValueError(f"Invalid time format: {time_str}")

    if has_days:
        # After "D-" the fields are HH[:MM[:SS]]; missing trailing fields are 0.
        hours, minutes, seconds = fields + [0] * (3 - len(fields))
    elif len(fields) == 1:  # MM
        hours, minutes, seconds = 0, fields[0], 0
    elif len(fields) == 2:  # MM:SS
        hours = 0
        minutes, seconds = fields
    else:  # HH:MM:SS
        hours, minutes, seconds = fields
    return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
def format_timedelta(td):
    """Render ``td`` as the zero-padded "DD:HH:MM:SS" string Lightning's max_time accepts.

    Fractional seconds are truncated.
    """
    secs = int(td.total_seconds())
    mins, secs = divmod(secs, 60)
    hrs, mins = divmod(mins, 60)
    dd, hrs = divmod(hrs, 24)
    return ":".join(f"{unit:02}" for unit in (dd, hrs, mins, secs))
def parse_config_value(content, pattern):
    """Return capture group 1 of the first ``pattern`` match in ``content``, whitespace-stripped.

    Returns None when the pattern does not match anywhere.
    """
    found = re.search(pattern, content)
    if found is None:
        return None
    return found.group(1).strip()
def select_qos(time_limit: timedelta) -> str:
    """Pick the shortest H100 QoS tier whose wall-clock ceiling covers ``time_limit``.

    Tier ceilings are inclusive: dev up to 2h, t3 up to 20h, t4 up to 100h.

    Raises:
        ValueError: If ``time_limit`` is longer than 100 hours.
    """
    tiers = (
        (timedelta(hours=2), "qos_gpu_h100-dev"),
        (timedelta(hours=20), "qos_gpu_h100-t3"),
        (timedelta(hours=100), "qos_gpu_h100-t4"),
    )
    for ceiling, qos_name in tiers:
        if time_limit <= ceiling:
            return qos_name
    raise ValueError(
        "Requested time exceeds maximum supported QoS window (100h). "
        "Please request 100:00:00 or less."
    )
def format_steps(steps_str):
    """Abbreviate a digit string of training steps with a k/m suffix.

    Values of at least one million become "<n>m" and values of at least one
    thousand become "<n>k"; the quotient is floor-divided. Smaller values are
    returned as a normalised integer string. Anything that is empty, None or
    not all digits is returned unchanged.
    """
    if steps_str and steps_str.isdigit():
        count = int(steps_str)
        for divisor, unit in ((1000000, "m"), (1000, "k")):
            if count >= divisor:
                return f"{count // divisor}{unit}"
        return str(count)
    return steps_str
def generate_wandb_name(config_path, num_gpus, suffix=None):
    """Build a W&B run name from an experiment YAML config.

    The name has the form ``model-data-<steps>-<batch>x<gpus>bs[-suffix]``.
    Each component is included only if it can be extracted from the config
    text; ``<steps>`` is abbreviated with ``format_steps``.

    Args:
        config_path: Path to the experiment YAML file.
        num_gpus: GPU count used in the ``<batch>x<gpus>bs`` component.
        suffix: Optional string appended as the last component.

    Returns:
        The dash-joined name, or "experiment" if the file is missing or no
        component could be extracted.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        print(
            f"Warning: Config file not found at {config_path}. Cannot auto-generate name."
        )
        return "experiment"

    # Patterns are line-anchored (?m): "^[^#\n]*" rejects keys that sit after a
    # YAML '#' comment marker. "\b" keeps e.g. `val_batch_size` from matching
    # `batch_size`. "[ \t]*" stops the value match from jumping to the next line.
    model = parse_config_value(content, r"(?m)^[^#\n]*override /model:[ \t]*(\S+)")
    dataset = parse_config_value(content, r"(?m)^[^#\n]*override /data:[ \t]*(\S+)")
    batch_size = parse_config_value(content, r"(?m)^[^#\n]*\bbatch_size:[ \t]*(\d+)")
    max_steps = parse_config_value(content, r"(?m)^[^#\n]*\bmax_steps:[ \t]*(\d+)")

    parts = []
    if model:
        parts.append(model)
    if dataset:
        parts.append(dataset)
    if max_steps:
        parts.append(format_steps(max_steps))
    if batch_size:
        parts.append(f"{batch_size}x{num_gpus}bs")
    if suffix:
        parts.append(suffix)

    # Fallback if parsing failed completely
    if not parts:
        return "experiment"
    return "-".join(parts)
def _build_overrides(gpus, unknown):
    """Return ``(strategy, extra_args)``: the Lightning strategy and shell-quoted Hydra overrides."""
    import shlex

    if gpus > 1:
        strategy = "ddp"
        # Sync BatchNorm is usually recommended for DDP; "++" force-adds the key
        # even if it is absent from the structured config.
        overrides = ["++trainer.sync_batchnorm=True"]
    else:
        strategy = "auto"
        overrides = []
    # Forward unknown CLI arguments (e.g. model.rq_lambda=0.5) to train.py.
    for arg in unknown:
        if arg.startswith(("+", "~")):
            overrides.append(arg)
        elif "=" in arg:
            # "++" force-adds/overrides, avoiding ConfigAttributeError when the
            # key is not in the struct while still working when it is.
            overrides.append("++" + arg)
        else:
            overrides.append(arg)
    # Quote each override so spaces, glob characters ([...]) and a leading "~"
    # reach Hydra exactly as given instead of being split or expanded by bash.
    return strategy, " ".join(shlex.quote(o) for o in overrides)


def _safe_filename(name):
    """Replace characters that are unsafe in a file name (e.g. '/') with '_'."""
    return re.sub(r"[^A-Za-z0-9._-]+", "_", name) or "job"


def _submit(script_path):
    """Submit ``script_path`` with sbatch; exit with a non-zero status on failure."""
    import sys

    try:
        result = subprocess.run(
            ["sbatch", script_path], check=True, capture_output=True, text=True
        )
    except FileNotFoundError:
        print("Error: 'sbatch' not found on PATH. Is Slurm available here?", file=sys.stderr)
        raise SystemExit(1)
    except subprocess.CalledProcessError as e:
        print("Error: Submission failed!", file=sys.stderr)
        print(f"Stderr: {e.stderr}", file=sys.stderr)
        # The generated script is kept on disk for debugging.
        raise SystemExit(e.returncode or 1)
    print(f"Submission successful: {result.stdout.strip()}")


def main():
    """Render a Slurm batch script for a training experiment and submit it.

    Unknown CLI arguments are forwarded to ``src/train.py`` as Hydra overrides.
    With ``--dry-run`` the rendered script is printed and nothing is written
    or submitted. Exits non-zero on invalid arguments or a failed submission.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate and submit Slurm jobs for Audio Embeddings. "
            "WandB run names are generated from the experiment config "
            "(model, data, max_steps, batch_size x GPUs), plus optional suffix."
        )
    )
    parser.add_argument(
        "experiment",
        type=str,
        help="Experiment config path (e.g., audio_jepa/baseline)",
    )
    parser.add_argument(
        "--gpus", type=int, default=1, help="Number of GPUs to request (default: 1)"
    )
    parser.add_argument(
        "--time",
        type=str,
        default="20:00:00",
        help=(
            "Slurm time limit: HH:MM:SS, MM, MM:SS, D-HH, D-HH:MM or "
            "D-HH:MM:SS (default: 20:00:00)"
        ),
    )
    parser.add_argument(
        "--suffix",
        type=str,
        help=(
            "Optional suffix for WandB run name. "
            "Base name is derived from the experiment config: "
            "model + data + max_steps (k/m) + batch_size x GPUs."
        ),
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the generated script without submitting",
    )
    parser.add_argument(
        "--ffmpeg-module",
        type=str,
        default="ffmpeg/6.1.1",
        help=(
            "FFmpeg environment module to load in the job script "
            "(default: ffmpeg/6.1.1)."
        ),
    )
    args, unknown = parser.parse_known_args()
    if args.gpus < 1:
        parser.error("--gpus must be at least 1")

    # 1. Trainer strategy and Hydra overrides
    strategy, extra_args = _build_overrides(args.gpus, unknown)

    workdir = os.path.abspath(os.getcwd())

    # 2. WandB name (also used as the Slurm job name).
    # Assumes the config lives at configs/experiment/{experiment}.yaml.
    config_path = os.path.join(
        workdir, "configs", "experiment", f"{args.experiment}.yaml"
    )
    wandb_name = generate_wandb_name(config_path, args.gpus, args.suffix)
    job_name = wandb_name

    # 3. QoS and trainer max_time (Slurm limit minus a 10-minute shutdown buffer)
    try:
        slurm_time_td = parse_slurm_time(args.time)
        qos = select_qos(slurm_time_td)
    except ValueError as e:
        parser.error(f"Invalid --time value '{args.time}': {e}")
    buffer_time = timedelta(minutes=10)
    if slurm_time_td > buffer_time:
        max_time_td = slurm_time_td - buffer_time
    else:
        print(
            f"Warning: Requested time {args.time} is less than buffer (10m). Using full time."
        )
        max_time_td = slurm_time_td
    max_time_str = format_timedelta(max_time_td)

    # 4. Render the job script
    script_content = TEMPLATE.format(
        job_name=job_name,
        qos=qos,
        time=args.time,
        gpus=args.gpus,
        workdir=workdir,
        experiment=args.experiment,
        strategy=strategy,
        wandb_name=wandb_name,
        max_time=max_time_str,
        ffmpeg_module=args.ffmpeg_module,
        extra_args=extra_args,
    )

    # 5. Dry run: show the script and stop without touching the filesystem
    if args.dry_run:
        print("--- Dry Run: Generated Slurm Script ---")
        print(script_content)
        print("---------------------------------------")
        return

    # 6. Write the script into a hidden directory under the working dir
    script_dir = os.path.join(workdir, "slurm_scripts", ".generated")
    os.makedirs(script_dir, exist_ok=True)
    # Slurm opens the --output/--error files (relative to the submit dir)
    # before the batch script runs, so the in-job `mkdir -p` is too late.
    os.makedirs(os.path.join(workdir, "logs", "slurm"), exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = os.path.join(
        script_dir, f"submit_{_safe_filename(job_name)}_{timestamp}.slurm"
    )
    with open(filename, "w", encoding="utf-8") as f:
        f.write(script_content)
    print(f"Generated script: {filename}")

    # 7. Submit to Slurm
    _submit(filename)
# Script entry point: hand main()'s return value (None on success) to the
# interpreter as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())