import base64 import json import os import subprocess import sys import time from pathlib import Path def _is_json_scalar(value): return value is None or isinstance(value, (bool, int, float, str)) def _to_json_safe(value): if _is_json_scalar(value): return value if isinstance(value, list): return [_to_json_safe(item) for item in value] if isinstance(value, tuple): return [_to_json_safe(item) for item in value] if isinstance(value, dict): converted = {} for key in value: converted[key] = _to_json_safe(value[key]) return converted if isinstance(value, Path): return str(value) return str(value) def _should_auto_run_modal(args): if "PROTIFY_JOB_ID" in os.environ and os.environ["PROTIFY_JOB_ID"] != "": return False if args.replay_path is not None: return False if not args.modal_cli_credentials_provided: return False return args.modal_token_id is not None and args.modal_token_secret is not None def _modal_subprocess_env(args): env = os.environ.copy() env["MODAL_TOKEN_ID"] = args.modal_token_id env["MODAL_TOKEN_SECRET"] = args.modal_token_secret env["PYTHONIOENCODING"] = "utf-8" env["PYTHONUTF8"] = "1" return env def _repo_root(): return Path(__file__).resolve().parents[2] def _deploy_modal_backend(args): repo_root = _repo_root() backend_path = repo_root / "src" / "protify" / "modal_backend.py" assert backend_path.exists(), f"Modal backend not found at {backend_path}" app_name = "protify-backend" env = _modal_subprocess_env(args) primary_command = [sys.executable, "-m", "modal", "deploy", str(backend_path), "--name", app_name] try: process = subprocess.run( primary_command, cwd=str(repo_root), env=env, capture_output=True, text=True, encoding="utf-8", errors="replace", ) except FileNotFoundError: fallback_command = ["modal", "deploy", str(backend_path), "--name", app_name] process = subprocess.run( fallback_command, cwd=str(repo_root), env=env, capture_output=True, text=True, encoding="utf-8", errors="replace", ) if process.returncode != 0: stderr_text = process.stderr if process.stderr is not None else "" stdout_text = process.stdout if process.stdout is not None else "" combined_output = f"{stdout_text}\n{stderr_text}".strip() if "No module named modal" in combined_output: raise RuntimeError("Modal is not installed in this Python environment. Install it with: py -m pip install modal") raise RuntimeError(f"Modal deploy failed:\n{combined_output}") stdout_text = process.stdout if process.stdout is not None else "" if stdout_text: print(stdout_text[-4000:]) def _build_modal_config_from_args(args): config = {} excluded_keys = { "modal_token_id", "modal_token_secret", "modal_api_key", "modal_cli_credentials_provided", "rebuild_modal", "delete_modal_embeddings", } for key in args.__dict__: if key in excluded_keys: continue config[key] = _to_json_safe(args.__dict__[key]) config["replay_path"] = None return config def _save_modal_artifacts(result_payload, output_root, job_id): output_root_path = Path(output_root) job_dir = output_root_path / job_id job_dir.mkdir(parents=True, exist_ok=True) files_payload = result_payload["files"] if "files" in result_payload else {} for rel_path in files_payload: local_path = job_dir / Path(rel_path) local_path.parent.mkdir(parents=True, exist_ok=True) with open(local_path, "w", encoding="utf-8") as file: file.write(files_payload[rel_path]) images_payload = result_payload["images"] if "images" in result_payload else {} for rel_path in images_payload: image_info = images_payload[rel_path] if "data" not in image_info: continue local_path = job_dir / Path(rel_path) local_path.parent.mkdir(parents=True, exist_ok=True) image_bytes = base64.b64decode(image_info["data"]) with open(local_path, "wb") as file: file.write(image_bytes) summary_path = job_dir / "modal_fetch_summary.json" with open(summary_path, "w", encoding="utf-8") as file: json.dump(result_payload, file, indent=2) return str(job_dir) def _coerce_modal_terminal_payload(remote_result): if isinstance(remote_result, dict): payload = dict(remote_result) if "status" not in payload: if "success" in payload and payload["success"]: payload["status"] = "SUCCESS" elif "success" in payload and not payload["success"]: payload["status"] = "FAILED" else: payload["status"] = "SUCCESS" return payload return {"status": "SUCCESS"} def _run_on_modal_cli(args): try: import modal except Exception as error: raise RuntimeError("Modal SDK is required for CLI remote execution. Install with: py -m pip install modal") from error app_name = "protify-backend" gpu_type = "A10" if "modal_gpu_type" in args.__dict__ and args.modal_gpu_type is not None: gpu_type = args.modal_gpu_type timeout_seconds = 86400 if "modal_timeout_seconds" in args.__dict__ and args.modal_timeout_seconds is not None: timeout_seconds = args.modal_timeout_seconds poll_interval_seconds = 5 if "modal_poll_interval_seconds" in args.__dict__ and args.modal_poll_interval_seconds is not None: poll_interval_seconds = args.modal_poll_interval_seconds log_tail_chars = 5000 if "modal_log_tail_chars" in args.__dict__ and args.modal_log_tail_chars is not None: log_tail_chars = args.modal_log_tail_chars max_stale_heartbeat_seconds = 600 if "modal_max_stale_heartbeat_seconds" in args.__dict__ and args.modal_max_stale_heartbeat_seconds is not None: max_stale_heartbeat_seconds = args.modal_max_stale_heartbeat_seconds artifacts_root = "modal_artifacts" if "modal_artifacts_dir" in args.__dict__ and args.modal_artifacts_dir is not None: artifacts_root = args.modal_artifacts_dir if args.rebuild_modal: print("Rebuilding Modal backend due to --rebuild_modal ...") _deploy_modal_backend(args) config = _build_modal_config_from_args(args) submit_fn = modal.Function.from_name(app_name, "submit_protify_job") status_fn = modal.Function.from_name(app_name, "get_job_status") log_delta_fn = modal.Function.from_name(app_name, "get_job_log_delta") results_fn = modal.Function.from_name(app_name, "get_results") delete_embeddings_fn = modal.Function.from_name(app_name, "delete_modal_embeddings") if args.delete_modal_embeddings: print("Deleting Modal embedding cache due to --delete_modal_embeddings ...") try: delete_embeddings_payload = delete_embeddings_fn.remote() except Exception: print("Modal embedding delete failed before app/function lookup succeeded; attempting deploy then retry...") _deploy_modal_backend(args) submit_fn = modal.Function.from_name(app_name, "submit_protify_job") status_fn = modal.Function.from_name(app_name, "get_job_status") log_delta_fn = modal.Function.from_name(app_name, "get_job_log_delta") results_fn = modal.Function.from_name(app_name, "get_results") delete_embeddings_fn = modal.Function.from_name(app_name, "delete_modal_embeddings") delete_embeddings_payload = delete_embeddings_fn.remote() if isinstance(delete_embeddings_payload, dict) and "message" in delete_embeddings_payload: print(delete_embeddings_payload["message"]) has_dataset_run = len(args.data_names) > 0 or len(args.data_dirs) > 0 if not has_dataset_run and not args.proteingym: return 0 try: submit_result = submit_fn.remote( config=config, gpu_type=gpu_type, hf_token=args.hf_token, wandb_api_key=args.wandb_api_key, synthyra_api_key=args.synthyra_api_key, timeout_seconds=timeout_seconds, ) except Exception: print("Modal submit failed before app/function lookup succeeded; attempting deploy then retry...") _deploy_modal_backend(args) submit_fn = modal.Function.from_name(app_name, "submit_protify_job") status_fn = modal.Function.from_name(app_name, "get_job_status") log_delta_fn = modal.Function.from_name(app_name, "get_job_log_delta") results_fn = modal.Function.from_name(app_name, "get_results") submit_result = submit_fn.remote( config=config, gpu_type=gpu_type, hf_token=args.hf_token, wandb_api_key=args.wandb_api_key, synthyra_api_key=args.synthyra_api_key, timeout_seconds=timeout_seconds, ) assert isinstance(submit_result, dict), "Modal submit response is not a dictionary." assert "job_id" in submit_result, "Modal submit response missing job_id." job_id = submit_result["job_id"] function_call_id = submit_result["function_call_id"] if "function_call_id" in submit_result else None print(f"Modal job submitted: {job_id}") if function_call_id is not None: print(f"Modal function call id: {function_call_id}") terminal_states = {"SUCCESS", "FAILED", "TERMINATED", "TIMEOUT"} final_status_payload = None poll_start_time = time.time() max_poll_seconds = int(timeout_seconds) + 900 status_print_interval_seconds = 15 last_status_print_time = 0.0 last_status_line = "" missing_status_count = 0 log_offset = 0 function_call = None if function_call_id is not None: function_call = modal.FunctionCall.from_id(function_call_id) def _emit_remote_logs(): nonlocal log_offset delta_payload = log_delta_fn.remote(job_id=job_id, offset=log_offset, max_chars=log_tail_chars) if isinstance(delta_payload, dict): if "next_offset" in delta_payload and isinstance(delta_payload["next_offset"], int): log_offset = delta_payload["next_offset"] if "chunk" in delta_payload and delta_payload["chunk"]: sys.stdout.write(delta_payload["chunk"]) sys.stdout.flush() while True: _emit_remote_logs() status_payload = status_fn.remote(job_id=job_id) assert isinstance(status_payload, dict), "Modal status response is not a dictionary." if "success" in status_payload and status_payload["success"]: missing_status_count = 0 status_value = status_payload["status"] if "status" in status_payload else "UNKNOWN" phase_value = status_payload["phase"] if "phase" in status_payload else "N/A" heartbeat_age = status_payload["heartbeat_age_seconds"] if "heartbeat_age_seconds" in status_payload else None heartbeat_text = "N/A" if heartbeat_age is None else f"{heartbeat_age:.1f}s" status_line = f"[Modal] status={status_value} phase={phase_value} heartbeat_age={heartbeat_text}" if status_value in terminal_states: final_status_payload = dict(status_payload) break else: missing_status_count += 1 status_line = "[Modal] state=queued_or_initializing" if missing_status_count % 6 == 0 and "error" in status_payload and status_payload["error"]: status_line = f"[Modal] state=queued_or_initializing detail={status_payload['error']}" now = time.time() if status_line != last_status_line or (now - last_status_print_time) >= status_print_interval_seconds: print(status_line) last_status_line = status_line last_status_print_time = now if function_call is not None: try: remote_result = function_call.get(timeout=0) final_status_payload = _coerce_modal_terminal_payload(remote_result) if "phase" not in final_status_payload and "phase" in status_payload: final_status_payload["phase"] = status_payload["phase"] break except TimeoutError: pass except Exception as error: final_status_payload = {"status": "FAILED", "error": f"Function call failed: {error}"} break elapsed_seconds = now - poll_start_time if elapsed_seconds > max_poll_seconds: final_status_payload = { "status": "TIMEOUT", "phase": "poll_timeout", "error": f"Polling exceeded timeout window ({max_poll_seconds} seconds).", } break if "success" in status_payload and status_payload["success"] and "heartbeat_age_seconds" in status_payload: heartbeat_age = status_payload["heartbeat_age_seconds"] if heartbeat_age is not None and heartbeat_age > max_stale_heartbeat_seconds and function_call is None: final_status_payload = { "status": "FAILED", "phase": "stale_heartbeat", "error": f"Heartbeat stale for {heartbeat_age:.1f}s with no function_call_id available.", } break time.sleep(max(1, int(poll_interval_seconds))) final_delta_payload = log_delta_fn.remote(job_id=job_id, offset=log_offset, max_chars=log_tail_chars * 8) if isinstance(final_delta_payload, dict): if "chunk" in final_delta_payload and final_delta_payload["chunk"]: sys.stdout.write(final_delta_payload["chunk"]) sys.stdout.flush() try: results_payload = results_fn.remote(job_id=job_id) except Exception as error: results_payload = {"success": False, "error": str(error)} if isinstance(results_payload, dict) and "success" in results_payload and results_payload["success"]: artifacts_dir = _save_modal_artifacts(results_payload, artifacts_root, job_id) print(f"Modal artifacts saved to {artifacts_dir}") if final_status_payload is None: final_status_payload = {"status": "FAILED", "error": "No terminal status was resolved."} final_status = final_status_payload["status"] if "status" in final_status_payload else "FAILED" if final_status != "SUCCESS": if "error" in final_status_payload and final_status_payload["error"]: print(f"Modal job failed: {final_status_payload['error']}") return 1 return 0