nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
import base64
import json
import os
import subprocess
import sys
import time
from pathlib import Path
def _is_json_scalar(value):
return value is None or isinstance(value, (bool, int, float, str))
def _to_json_safe(value):
    """Recursively convert *value* into a structure ``json.dump`` can write.

    Scalars accepted by ``_is_json_scalar`` are returned unchanged. Lists and
    tuples become lists of converted items. Dicts are rebuilt with converted
    values. Every other object (including ``Path``) is replaced by
    ``str(value)``.

    Dict keys that are not JSON scalars (for example tuples or ``Path``
    objects) are stringified as well, because ``json.dump`` raises
    ``TypeError`` on such keys with its default settings.
    """
    if _is_json_scalar(value):
        return value
    if isinstance(value, (list, tuple)):
        return [_to_json_safe(item) for item in value]
    if isinstance(value, dict):
        return {
            (key if _is_json_scalar(key) else str(key)): _to_json_safe(item)
            for key, item in value.items()
        }
    # Path objects and anything else unknown fall back to their string form.
    return str(value)
def _should_auto_run_modal(args):
if "PROTIFY_JOB_ID" in os.environ and os.environ["PROTIFY_JOB_ID"] != "":
return False
if args.replay_path is not None:
return False
if not args.modal_cli_credentials_provided:
return False
return args.modal_token_id is not None and args.modal_token_secret is not None
def _modal_subprocess_env(args):
env = os.environ.copy()
env["MODAL_TOKEN_ID"] = args.modal_token_id
env["MODAL_TOKEN_SECRET"] = args.modal_token_secret
env["PYTHONIOENCODING"] = "utf-8"
env["PYTHONUTF8"] = "1"
return env
def _repo_root():
return Path(__file__).resolve().parents[2]
def _deploy_modal_backend(args):
    """Deploy ``src/protify/modal_backend.py`` as the ``protify-backend`` app.

    Runs ``<python> -m modal deploy`` from the repository root with the
    credentials from *args* in the child environment. If launching the
    interpreter raises ``FileNotFoundError``, the bare ``modal`` executable is
    tried instead. On success the last 4000 characters of stdout are printed.

    Raises:
        AssertionError: the backend file does not exist.
        RuntimeError: the deploy command exited with a non-zero status.
    """
    root = _repo_root()
    backend_file = root / "src" / "protify" / "modal_backend.py"
    assert backend_file.exists(), f"Modal backend not found at {backend_file}"

    deploy_args = ["deploy", str(backend_file), "--name", "protify-backend"]
    run_options = {
        "cwd": str(root),
        "env": _modal_subprocess_env(args),
        "capture_output": True,
        "text": True,
        "encoding": "utf-8",
        "errors": "replace",
    }
    try:
        completed = subprocess.run([sys.executable, "-m", "modal", *deploy_args], **run_options)
    except FileNotFoundError:
        # The interpreter could not be launched; try the standalone CLI.
        completed = subprocess.run(["modal", *deploy_args], **run_options)

    captured_stdout = "" if completed.stdout is None else completed.stdout
    if completed.returncode != 0:
        captured_stderr = "" if completed.stderr is None else completed.stderr
        combined_output = f"{captured_stdout}\n{captured_stderr}".strip()
        if "No module named modal" in combined_output:
            raise RuntimeError("Modal is not installed in this Python environment. Install it with: py -m pip install modal")
        raise RuntimeError(f"Modal deploy failed:\n{combined_output}")

    if captured_stdout:
        print(captured_stdout[-4000:])
def _build_modal_config_from_args(args):
    """Turn parsed CLI arguments into the JSON-safe config sent to the backend.

    Every attribute of *args* is copied through ``_to_json_safe`` except the
    Modal credential and CLI-control options listed below. ``replay_path`` is
    always set to ``None`` in the result.
    """
    local_only_options = frozenset(
        (
            "modal_token_id",
            "modal_token_secret",
            "modal_api_key",
            "modal_cli_credentials_provided",
            "rebuild_modal",
            "delete_modal_embeddings",
        )
    )
    config = {
        name: _to_json_safe(option_value)
        for name, option_value in vars(args).items()
        if name not in local_only_options
    }
    config["replay_path"] = None
    return config
def _save_modal_artifacts(result_payload, output_root, job_id):
output_root_path = Path(output_root)
job_dir = output_root_path / job_id
job_dir.mkdir(parents=True, exist_ok=True)
files_payload = result_payload["files"] if "files" in result_payload else {}
for rel_path in files_payload:
local_path = job_dir / Path(rel_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
with open(local_path, "w", encoding="utf-8") as file:
file.write(files_payload[rel_path])
images_payload = result_payload["images"] if "images" in result_payload else {}
for rel_path in images_payload:
image_info = images_payload[rel_path]
if "data" not in image_info:
continue
local_path = job_dir / Path(rel_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
image_bytes = base64.b64decode(image_info["data"])
with open(local_path, "wb") as file:
file.write(image_bytes)
summary_path = job_dir / "modal_fetch_summary.json"
with open(summary_path, "w", encoding="utf-8") as file:
json.dump(result_payload, file, indent=2)
return str(job_dir)
def _coerce_modal_terminal_payload(remote_result):
if isinstance(remote_result, dict):
payload = dict(remote_result)
if "status" not in payload:
if "success" in payload and payload["success"]:
payload["status"] = "SUCCESS"
elif "success" in payload and not payload["success"]:
payload["status"] = "FAILED"
else:
payload["status"] = "SUCCESS"
return payload
return {"status": "SUCCESS"}
def _modal_option(args, name, default):
    """Return ``args.<name>`` if it is present and not ``None``, else *default*."""
    value = getattr(args, name, None)
    return default if value is None else value


def _lookup_modal_functions(modal, app_name):
    """Build handles for the backend functions this CLI calls, keyed by role."""
    remote_names = {
        "submit": "submit_protify_job",
        "status": "get_job_status",
        "log_delta": "get_job_log_delta",
        "results": "get_results",
        "delete_embeddings": "delete_modal_embeddings",
    }
    return {
        role: modal.Function.from_name(app_name, remote_name)
        for role, remote_name in remote_names.items()
    }


def _run_on_modal_cli(args):
    """Submit this run to the ``protify-backend`` Modal app and follow it to completion.

    Steps:
      1. Optionally redeploy the backend (``--rebuild_modal``).
      2. Optionally clear the remote embedding cache
         (``--delete_modal_embeddings``); if no dataset or ProteinGym run was
         requested alongside it, stop there.
      3. Submit the job. If the first attempt raises, deploy the backend and
         retry once.
      4. Poll status and stream log deltas until a terminal state is reached,
         the function call finishes, polling times out, or the heartbeat goes
         stale with no function call to consult.
      5. Flush remaining logs, download artifacts, and report the outcome.

    Returns:
        ``0`` if the job ended with status ``SUCCESS`` (or only the embedding
        cache deletion was requested), ``1`` otherwise.

    Raises:
        RuntimeError: the ``modal`` package cannot be imported.
        AssertionError: the submit or status response is not a dict, or the
            submit response lacks ``job_id``.
    """
    try:
        import modal
    except Exception as error:
        raise RuntimeError("Modal SDK is required for CLI remote execution. Install with: py -m pip install modal") from error

    app_name = "protify-backend"
    gpu_type = _modal_option(args, "modal_gpu_type", "A10")
    timeout_seconds = _modal_option(args, "modal_timeout_seconds", 86400)
    poll_interval_seconds = _modal_option(args, "modal_poll_interval_seconds", 5)
    log_tail_chars = _modal_option(args, "modal_log_tail_chars", 5000)
    max_stale_heartbeat_seconds = _modal_option(args, "modal_max_stale_heartbeat_seconds", 600)
    artifacts_root = _modal_option(args, "modal_artifacts_dir", "modal_artifacts")

    if args.rebuild_modal:
        print("Rebuilding Modal backend due to --rebuild_modal ...")
        _deploy_modal_backend(args)

    config = _build_modal_config_from_args(args)
    functions = _lookup_modal_functions(modal, app_name)

    if args.delete_modal_embeddings:
        print("Deleting Modal embedding cache due to --delete_modal_embeddings ...")
        try:
            delete_embeddings_payload = functions["delete_embeddings"].remote()
        except Exception:
            print("Modal embedding delete failed before app/function lookup succeeded; attempting deploy then retry...")
            _deploy_modal_backend(args)
            functions = _lookup_modal_functions(modal, app_name)
            delete_embeddings_payload = functions["delete_embeddings"].remote()
        if isinstance(delete_embeddings_payload, dict) and "message" in delete_embeddings_payload:
            print(delete_embeddings_payload["message"])
        # A delete-only invocation (nothing to train or evaluate) ends here.
        has_dataset_run = len(args.data_names) > 0 or len(args.data_dirs) > 0
        if not has_dataset_run and not args.proteingym:
            return 0

    submit_kwargs = {
        "config": config,
        "gpu_type": gpu_type,
        "hf_token": args.hf_token,
        "wandb_api_key": args.wandb_api_key,
        "synthyra_api_key": args.synthyra_api_key,
        "timeout_seconds": timeout_seconds,
    }
    try:
        submit_result = functions["submit"].remote(**submit_kwargs)
    except Exception:
        print("Modal submit failed before app/function lookup succeeded; attempting deploy then retry...")
        _deploy_modal_backend(args)
        functions = _lookup_modal_functions(modal, app_name)
        submit_result = functions["submit"].remote(**submit_kwargs)

    status_fn = functions["status"]
    log_delta_fn = functions["log_delta"]
    results_fn = functions["results"]

    assert isinstance(submit_result, dict), "Modal submit response is not a dictionary."
    assert "job_id" in submit_result, "Modal submit response missing job_id."
    job_id = submit_result["job_id"]
    function_call_id = submit_result.get("function_call_id")
    print(f"Modal job submitted: {job_id}")
    if function_call_id is not None:
        print(f"Modal function call id: {function_call_id}")

    terminal_states = {"SUCCESS", "FAILED", "TERMINATED", "TIMEOUT"}
    final_status_payload = None
    poll_start_time = time.time()
    # Allow some slack beyond the job's own timeout before giving up on polling.
    max_poll_seconds = int(timeout_seconds) + 900
    status_print_interval_seconds = 15
    last_status_print_time = 0.0
    last_status_line = ""
    missing_status_count = 0
    log_offset = 0
    function_call = None
    if function_call_id is not None:
        function_call = modal.FunctionCall.from_id(function_call_id)

    def _emit_remote_logs():
        # Print log text produced since the last call and advance the offset.
        nonlocal log_offset
        delta_payload = log_delta_fn.remote(job_id=job_id, offset=log_offset, max_chars=log_tail_chars)
        if isinstance(delta_payload, dict):
            if isinstance(delta_payload.get("next_offset"), int):
                log_offset = delta_payload["next_offset"]
            if delta_payload.get("chunk"):
                sys.stdout.write(delta_payload["chunk"])
                sys.stdout.flush()

    while True:
        _emit_remote_logs()
        status_payload = status_fn.remote(job_id=job_id)
        assert isinstance(status_payload, dict), "Modal status response is not a dictionary."

        if status_payload.get("success"):
            missing_status_count = 0
            status_value = status_payload.get("status", "UNKNOWN")
            phase_value = status_payload.get("phase", "N/A")
            heartbeat_age = status_payload.get("heartbeat_age_seconds")
            heartbeat_text = "N/A" if heartbeat_age is None else f"{heartbeat_age:.1f}s"
            status_line = f"[Modal] status={status_value} phase={phase_value} heartbeat_age={heartbeat_text}"
            if status_value in terminal_states:
                final_status_payload = dict(status_payload)
                break
        else:
            # No status record could be read for the job on this poll.
            missing_status_count += 1
            status_line = "[Modal] state=queued_or_initializing"
            if missing_status_count % 6 == 0 and status_payload.get("error"):
                status_line = f"[Modal] state=queued_or_initializing detail={status_payload['error']}"

        # Print when the line changes, or periodically as a keep-alive.
        now = time.time()
        if status_line != last_status_line or (now - last_status_print_time) >= status_print_interval_seconds:
            print(status_line)
            last_status_line = status_line
            last_status_print_time = now

        if function_call is not None:
            # Non-blocking check: a finished call ends polling even if the
            # status record never reached a terminal state.
            try:
                remote_result = function_call.get(timeout=0)
                final_status_payload = _coerce_modal_terminal_payload(remote_result)
                if "phase" not in final_status_payload and "phase" in status_payload:
                    final_status_payload["phase"] = status_payload["phase"]
                break
            except TimeoutError:
                pass
            except Exception as error:
                final_status_payload = {"status": "FAILED", "error": f"Function call failed: {error}"}
                break

        elapsed_seconds = now - poll_start_time
        if elapsed_seconds > max_poll_seconds:
            final_status_payload = {
                "status": "TIMEOUT",
                "phase": "poll_timeout",
                "error": f"Polling exceeded timeout window ({max_poll_seconds} seconds).",
            }
            break

        if status_payload.get("success"):
            heartbeat_age = status_payload.get("heartbeat_age_seconds")
            # Without a function call handle, a stale heartbeat is the only
            # sign that the job has died.
            if heartbeat_age is not None and heartbeat_age > max_stale_heartbeat_seconds and function_call is None:
                final_status_payload = {
                    "status": "FAILED",
                    "phase": "stale_heartbeat",
                    "error": f"Heartbeat stale for {heartbeat_age:.1f}s with no function_call_id available.",
                }
                break

        time.sleep(max(1, int(poll_interval_seconds)))

    # The job outcome is already known here, so a failed log flush must not
    # prevent artifact download or hide the exit status.
    try:
        final_delta_payload = log_delta_fn.remote(job_id=job_id, offset=log_offset, max_chars=log_tail_chars * 8)
    except Exception as error:
        print(f"Warning: could not fetch final Modal log output: {error}")
        final_delta_payload = None
    if isinstance(final_delta_payload, dict) and final_delta_payload.get("chunk"):
        sys.stdout.write(final_delta_payload["chunk"])
        sys.stdout.flush()

    try:
        results_payload = results_fn.remote(job_id=job_id)
    except Exception as error:
        results_payload = {"success": False, "error": str(error)}
    if isinstance(results_payload, dict) and results_payload.get("success"):
        artifacts_dir = _save_modal_artifacts(results_payload, artifacts_root, job_id)
        print(f"Modal artifacts saved to {artifacts_dir}")

    if final_status_payload is None:
        final_status_payload = {"status": "FAILED", "error": "No terminal status was resolved."}
    final_status = final_status_payload.get("status", "FAILED")
    if final_status != "SUCCESS":
        if final_status_payload.get("error"):
            print(f"Modal job failed: {final_status_payload['error']}")
        return 1
    return 0