"""Launch N RunPod A40 pods, deploy the codebase, kick off training. Usage: python scripts/runpod_launch.py --models A B C F --gpu A40 \ --image runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04 For each model letter: 1. create pod 2. wait for SSH 3. rsync repo + .env via scp 4. run pod_bootstrap.sh on the pod (in tmux/nohup) 5. record pod id + run name in runs/launch_manifest.json Polling/log retrieval is left to scripts/runpod_status.py. """ from __future__ import annotations import argparse import json import os import shutil import subprocess import sys import tempfile import time from pathlib import Path from dotenv import load_dotenv load_dotenv() RUNPOD_API_KEY = os.environ["RUNPOD_API_KEY"] GPU_IDS = { "A40": "NVIDIA A40", "A6000": "NVIDIA RTX A6000", "A100": "NVIDIA A100-SXM4-80GB", "H100": "NVIDIA H100 80GB HBM3", } DEFAULT_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04" def runpodctl(args: list[str], capture: bool = True) -> str: env = {**os.environ, "RUNPOD_API_KEY": RUNPOD_API_KEY} res = subprocess.run( ["runpodctl", *args], env=env, capture_output=capture, text=True ) if res.returncode != 0: raise RuntimeError(f"runpodctl {' '.join(args)} failed: {res.stderr}\n{res.stdout}") return res.stdout def create_pod(name: str, gpu_id: str, image: str, container_disk: int = 50, volume_gb: int = 100) -> dict: out = runpodctl([ "pod", "create", "--name", name, "--gpu-id", gpu_id, "--gpu-count", "1", "--image", image, "--cloud-type", "COMMUNITY", "--container-disk-in-gb", str(container_disk), "--volume-in-gb", str(volume_gb), "--volume-mount-path", "/workspace", "--ports", "22/tcp", "--ssh", ]) pod = json.loads(out) return pod def wait_for_ssh(pod_id: str, timeout: int = 600) -> tuple[str, int]: start = time.time() last_err = "" while time.time() - start < timeout: try: info = json.loads(runpodctl(["ssh", "info", pod_id])) host = info.get("publicIp") or info.get("ip") port = info.get("port") or info.get("sshPort") if host and port: return host, int(port) except Exception as e: last_err = str(e) time.sleep(15) raise TimeoutError(f"SSH not ready for {pod_id}: {last_err}") def ssh(host: str, port: int, cmd: str, user: str = "root", timeout: int = 60) -> str: res = subprocess.run([ "ssh", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=15", "-p", str(port), f"{user}@{host}", cmd, ], capture_output=True, text=True, timeout=timeout) if res.returncode != 0: raise RuntimeError(f"ssh {host}:{port} {cmd!r} failed: {res.stderr}") return res.stdout def scp(host: str, port: int, local_path: Path, remote_path: str, user: str = "root") -> None: cmd = ["scp", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-P", str(port)] if local_path.is_dir(): cmd.append("-r") cmd.extend([str(local_path), f"{user}@{host}:{remote_path}"]) res = subprocess.run(cmd, capture_output=True, text=True, timeout=900) if res.returncode != 0: raise RuntimeError(f"scp {local_path} -> {host}:{remote_path} failed: {res.stderr}") def deploy_and_launch(host: str, port: int, model: str, run_name: str, repo_root: Path) -> None: # build a tarball excluding bulky dirs with tempfile.TemporaryDirectory() as td: tar = Path(td) / "physiojepa.tar.gz" excludes = [".venv", ".git", "__pycache__", "runs", "cache", "docs/figures", "docs/paperes"] excl_args = [] for e in excludes: excl_args.extend(["--exclude", e]) subprocess.run( ["tar", "-czf", str(tar), *excl_args, "-C", str(repo_root.parent), repo_root.name], check=True, ) scp(host, port, tar, "/workspace/physiojepa.tar.gz") # also send .env env_file = repo_root / ".env" scp(host, port, env_file, "/workspace/.env") ssh(host, port, "set -e; cd /workspace && rm -rf physiojepa && " "tar -xzf physiojepa.tar.gz && rm physiojepa.tar.gz") # background the bootstrap with nohup so SSH disconnect doesn't kill it bootstrap = ( f"set -e; mkdir -p /workspace/runs; " f"cd /workspace/physiojepa && chmod +x scripts/pod_bootstrap.sh && " f"nohup bash scripts/pod_bootstrap.sh {model} {run_name} " f"> /workspace/runs/{run_name}.bootstrap.log 2>&1 &" f" disown; echo started; sleep 1" ) ssh(host, port, bootstrap) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--models", nargs="+", default=["A", "B", "C", "F"]) ap.add_argument("--gpu", default="A40", choices=list(GPU_IDS.keys())) ap.add_argument("--image", default=DEFAULT_IMAGE) ap.add_argument("--repo_root", default=str(Path(__file__).resolve().parents[1])) ap.add_argument("--manifest", default="runs/launch_manifest.json") args = ap.parse_args() repo_root = Path(args.repo_root) Path(args.manifest).parent.mkdir(parents=True, exist_ok=True) gpu_id = GPU_IDS[args.gpu] manifest = [] for model in args.models: run_name = f"e2_{model}_a40" pod_name = f"pj-{model.lower()}-{int(time.time()) % 100000:05d}" print(f"[launch] creating pod {pod_name} (model={model}, gpu={args.gpu})") pod = create_pod(pod_name, gpu_id, args.image) pod_id = pod.get("id") or pod.get("podId") print(f"[launch] pod_id={pod_id}, waiting for SSH...") try: host, port = wait_for_ssh(pod_id) except TimeoutError as e: print(f"[launch] WARN: {e}; deleting pod and continuing") try: runpodctl(["pod", "delete", pod_id]) except Exception: pass continue print(f"[launch] SSH up @ {host}:{port}, deploying code") deploy_and_launch(host, port, model, run_name, repo_root) manifest.append({"pod_id": pod_id, "pod_name": pod_name, "host": host, "port": port, "model": model, "run_name": run_name, "started_at": time.time()}) Path(args.manifest).write_text(json.dumps(manifest, indent=2)) print(f"[launch] {model} kicked off; manifest -> {args.manifest}") print(f"[launch] all done. manifest:\n{Path(args.manifest).read_text()}") if __name__ == "__main__": main()