"""Launch N RunPod A40 pods, deploy the codebase, kick off training.
Usage:
python scripts/runpod_launch.py --models A B C F --gpu A40 \
--image runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04
For each model letter:
1. create pod
2. wait for SSH
3. rsync repo + .env via scp
4. run pod_bootstrap.sh on the pod (in tmux/nohup)
5. record pod id + run name in runs/launch_manifest.json
Polling/log retrieval is left to scripts/runpod_status.py.
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from dotenv import load_dotenv
# Pull secrets (notably RUNPOD_API_KEY) from a repo-local .env file into os.environ.
load_dotenv()
# Fail fast at import time if the key is missing, rather than mid-launch.
RUNPOD_API_KEY = os.environ["RUNPOD_API_KEY"]
# Short CLI GPU names -> full RunPod GPU identifiers accepted by `runpodctl pod create`.
GPU_IDS = {
    "A40": "NVIDIA A40",
    "A6000": "NVIDIA RTX A6000",
    "A100": "NVIDIA A100-SXM4-80GB",
    "H100": "NVIDIA H100 80GB HBM3",
}
# Default pod container image; overridable via --image.
DEFAULT_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
def runpodctl(args: list[str], capture: bool = True) -> str:
    """Invoke the `runpodctl` CLI with the API key exported; return stdout.

    Raises RuntimeError (carrying stderr and stdout) on a nonzero exit code.
    """
    child_env = {**os.environ, "RUNPOD_API_KEY": RUNPOD_API_KEY}
    proc = subprocess.run(
        ["runpodctl", *args],
        env=child_env,
        capture_output=capture,
        text=True,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"runpodctl {' '.join(args)} failed: {proc.stderr}\n{proc.stdout}")
    return proc.stdout
def create_pod(name: str, gpu_id: str, image: str, container_disk: int = 50,
               volume_gb: int = 100) -> dict:
    """Create a single-GPU community-cloud pod; return the parsed JSON pod record.

    Mounts a persistent volume at /workspace and exposes SSH on port 22.
    """
    create_args = [
        "pod", "create",
        "--name", name,
        "--gpu-id", gpu_id,
        "--gpu-count", "1",
        "--image", image,
        "--cloud-type", "COMMUNITY",
        "--container-disk-in-gb", str(container_disk),
        "--volume-in-gb", str(volume_gb),
        "--volume-mount-path", "/workspace",
        "--ports", "22/tcp",
        "--ssh",
    ]
    return json.loads(runpodctl(create_args))
def wait_for_ssh(pod_id: str, timeout: int = 600) -> tuple[str, int]:
    """Poll `runpodctl ssh info` every 15s until a public host/port is reported.

    Returns (host, port). Raises TimeoutError, including the last error seen,
    if nothing usable appears within `timeout` seconds.
    """
    deadline = time.time() + timeout
    last_err = ""
    while time.time() < deadline:
        try:
            info = json.loads(runpodctl(["ssh", "info", pod_id]))
        except Exception as exc:  # CLI failure or unparseable JSON -- keep polling
            last_err = str(exc)
        else:
            host = info.get("publicIp") or info.get("ip")
            port = info.get("port") or info.get("sshPort")
            if host and port:
                return host, int(port)
        time.sleep(15)
    raise TimeoutError(f"SSH not ready for {pod_id}: {last_err}")
def ssh(host: str, port: int, cmd: str, user: str = "root", timeout: int = 60) -> str:
    """Run `cmd` on the pod over SSH (host-key checks disabled); return stdout.

    Raises RuntimeError with stderr attached if the remote command fails.
    """
    argv = [
        "ssh",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-o", "ConnectTimeout=15",
        "-p", str(port),
        f"{user}@{host}",
        cmd,
    ]
    proc = subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
    if proc.returncode != 0:
        raise RuntimeError(f"ssh {host}:{port} {cmd!r} failed: {proc.stderr}")
    return proc.stdout
def scp(host: str, port: int, local_path: Path, remote_path: str, user: str = "root") -> None:
    """Copy a local file (or directory, recursively) to the pod via scp.

    Raises RuntimeError with stderr attached if the transfer fails.
    """
    argv = [
        "scp",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
        "-P", str(port),
    ]
    if local_path.is_dir():
        argv.append("-r")
    argv += [str(local_path), f"{user}@{host}:{remote_path}"]
    proc = subprocess.run(argv, capture_output=True, text=True, timeout=900)
    if proc.returncode != 0:
        raise RuntimeError(f"scp {local_path} -> {host}:{remote_path} failed: {proc.stderr}")
def deploy_and_launch(host: str, port: int, model: str, run_name: str, repo_root: Path) -> None:
    """Ship the repo to the pod as a tarball and start training in the background.

    Steps: tar the repo (excluding bulky dirs), scp tarball + .env to
    /workspace, unpack remotely, then nohup the bootstrap script so the
    training run survives SSH disconnect.
    """
    # build a tarball excluding bulky dirs
    with tempfile.TemporaryDirectory() as td:
        tar = Path(td) / "physiojepa.tar.gz"
        # NOTE(review): "docs/paperes" looks like a typo for "docs/papers" --
        # confirm against the repo layout; a wrong pattern just fails to exclude.
        excludes = [".venv", ".git", "__pycache__", "runs", "cache", "docs/figures",
                    "docs/paperes"]
        excl_args = []
        for e in excludes:
            excl_args.extend(["--exclude", e])
        # -C parent + repo dir name => archive contains one top-level repo folder.
        subprocess.run(
            ["tar", "-czf", str(tar), *excl_args, "-C", str(repo_root.parent),
             repo_root.name],
            check=True,
        )
        # Transfer inside the `with` so the temp tarball still exists.
        scp(host, port, tar, "/workspace/physiojepa.tar.gz")
    # also send .env
    # NOTE(review): scp raises if .env is absent -- presumably the run requires it.
    env_file = repo_root / ".env"
    scp(host, port, env_file, "/workspace/.env")
    # Fresh unpack: any previous checkout on the pod is discarded first.
    ssh(host, port, "set -e; cd /workspace && rm -rf physiojepa && "
        "tar -xzf physiojepa.tar.gz && rm physiojepa.tar.gz")
    # background the bootstrap with nohup so SSH disconnect doesn't kill it
    bootstrap = (
        f"set -e; mkdir -p /workspace/runs; "
        f"cd /workspace/physiojepa && chmod +x scripts/pod_bootstrap.sh && "
        f"nohup bash scripts/pod_bootstrap.sh {model} {run_name} "
        f"> /workspace/runs/{run_name}.bootstrap.log 2>&1 &"
        f" disown; echo started; sleep 1"
    )
    ssh(host, port, bootstrap)
def main() -> None:
    """Create one pod per requested model, deploy the repo, start training.

    The manifest is rewritten after every successful launch, so a crash
    mid-run still leaves a usable record of the pods already started.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--models", nargs="+", default=["A", "B", "C", "F"])
    ap.add_argument("--gpu", default="A40", choices=list(GPU_IDS.keys()))
    ap.add_argument("--image", default=DEFAULT_IMAGE)
    ap.add_argument("--repo_root", default=str(Path(__file__).resolve().parents[1]))
    ap.add_argument("--manifest", default="runs/launch_manifest.json")
    args = ap.parse_args()
    repo_root = Path(args.repo_root)
    manifest_path = Path(args.manifest)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    gpu_id = GPU_IDS[args.gpu]
    manifest: list[dict] = []
    for model in args.models:
        run_name = f"e2_{model}_a40"
        pod_name = f"pj-{model.lower()}-{int(time.time()) % 100000:05d}"
        print(f"[launch] creating pod {pod_name} (model={model}, gpu={args.gpu})")
        pod = create_pod(pod_name, gpu_id, args.image)
        pod_id = pod.get("id") or pod.get("podId")
        print(f"[launch] pod_id={pod_id}, waiting for SSH...")
        try:
            host, port = wait_for_ssh(pod_id)
            print(f"[launch] SSH up @ {host}:{port}, deploying code")
            # FIX: deploy failures previously crashed the whole script and left
            # an idle (billed) pod running; now they get the same delete-and-
            # continue cleanup that SSH timeouts already got.
            deploy_and_launch(host, port, model, run_name, repo_root)
        except (TimeoutError, RuntimeError, subprocess.SubprocessError) as e:
            print(f"[launch] WARN: {e}; deleting pod and continuing")
            try:
                runpodctl(["pod", "delete", pod_id])
            except Exception:
                pass  # best-effort cleanup; pod may already be gone
            continue
        manifest.append({"pod_id": pod_id, "pod_name": pod_name, "host": host,
                         "port": port, "model": model, "run_name": run_name,
                         "started_at": time.time()})
        manifest_path.write_text(json.dumps(manifest, indent=2))
        print(f"[launch] {model} kicked off; manifest -> {args.manifest}")
    print(f"[launch] all done. manifest:\n{manifest_path.read_text()}")
# Script entry point: run the launcher only when executed directly, not on import.
if __name__ == "__main__":
    main()