| import base64 | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| import time | |
| from pathlib import Path | |
def _is_json_scalar(value):
    """Report whether ``value`` is ``None`` or an instance of bool, int, float or str."""
    if value is None:
        return True
    scalar_types = (bool, int, float, str)
    return isinstance(value, scalar_types)
def _to_json_safe(value):
    """Recursively convert ``value`` into a structure ``json.dumps`` accepts.

    Conversion rules:
      * ``None``, bool, int, float and str are returned unchanged.
      * Lists and tuples become new lists with every item converted.
      * Dicts become new dicts with every value converted. Keys that are
        already JSON scalars are kept; any other key (for example a ``Path``
        or a tuple) is replaced by ``str(key)``, because ``json.dumps``
        raises ``TypeError`` for such keys. If two keys stringify to the
        same text, the later one wins.
      * Every other object (including ``Path``) is replaced by ``str(value)``.

    Args:
        value: Arbitrary Python object.

    Returns:
        A structure made only of dicts, lists and JSON scalars.
    """
    scalar_types = (bool, int, float, str)
    if value is None or isinstance(value, scalar_types):
        return value
    if isinstance(value, (list, tuple)):
        return [_to_json_safe(item) for item in value]
    if isinstance(value, dict):
        converted = {}
        for key, item in value.items():
            key_is_scalar = key is None or isinstance(key, scalar_types)
            safe_key = key if key_is_scalar else str(key)
            converted[safe_key] = _to_json_safe(item)
        return converted
    # Fallback for Path and any other object: use its string form.
    return str(value)
def _should_auto_run_modal(args):
    """Decide whether the run should be dispatched to Modal automatically.

    Returns ``True`` only when all of the following hold:
      * the ``PROTIFY_JOB_ID`` environment variable is unset or empty,
      * ``args.replay_path`` is ``None``,
      * ``args.modal_cli_credentials_provided`` is truthy,
      * ``args.modal_token_id`` and ``args.modal_token_secret`` are both
        not ``None``.
    """
    job_id_env = os.environ.get("PROTIFY_JOB_ID", "")
    if job_id_env != "":
        return False
    if args.replay_path is not None or not args.modal_cli_credentials_provided:
        return False
    if args.modal_token_id is None:
        return False
    return args.modal_token_secret is not None
def _modal_subprocess_env(args):
    """Build the environment mapping used for Modal CLI subprocesses.

    Starts from a copy of ``os.environ``, overlays ``MODAL_TOKEN_ID`` and
    ``MODAL_TOKEN_SECRET`` from ``args``, and sets ``PYTHONIOENCODING`` and
    ``PYTHONUTF8`` so the child process uses UTF-8.

    A credential that is ``None`` is not written. Environment values must be
    strings, so storing ``None`` would make ``subprocess.run`` raise
    ``TypeError``; skipping it leaves any inherited value in place.

    Args:
        args: Object exposing ``modal_token_id`` and ``modal_token_secret``.

    Returns:
        dict: Environment mapping containing only string values.
    """
    env = os.environ.copy()
    credentials = {
        "MODAL_TOKEN_ID": args.modal_token_id,
        "MODAL_TOKEN_SECRET": args.modal_token_secret,
    }
    for name, value in credentials.items():
        if value is not None:
            env[name] = str(value)
    env["PYTHONIOENCODING"] = "utf-8"
    env["PYTHONUTF8"] = "1"
    return env
def _repo_root():
    """Return ``parents[2]`` of this module's resolved file path."""
    this_file = Path(__file__).resolve()
    return this_file.parents[2]
def _deploy_modal_backend(args):
    """Deploy ``src/protify/modal_backend.py`` as the ``protify-backend`` Modal app.

    Runs ``<sys.executable> -m modal deploy ...`` from the repository root
    with the environment built by ``_modal_subprocess_env``. If launching
    that command raises ``FileNotFoundError``, the bare ``modal`` executable
    is tried instead. On success the last 4000 characters of stdout are
    printed.

    Args:
        args: Object passed through to ``_modal_subprocess_env``.

    Raises:
        AssertionError: The backend file does not exist.
        RuntimeError: The deploy command exited with a non-zero status. The
            message names a missing Modal installation when the output
            contains "No module named modal".
    """
    repo_root = _repo_root()
    backend_path = repo_root / "src" / "protify" / "modal_backend.py"
    if not backend_path.exists():
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(f"Modal backend not found at {backend_path}")
    app_name = "protify-backend"
    deploy_args = ["deploy", str(backend_path), "--name", app_name]
    run_options = {
        "cwd": str(repo_root),
        "env": _modal_subprocess_env(args),
        "capture_output": True,
        "text": True,
        "encoding": "utf-8",
        "errors": "replace",
    }
    try:
        process = subprocess.run([sys.executable, "-m", "modal", *deploy_args], **run_options)
    except FileNotFoundError:
        # The interpreter path could not be launched; fall back to `modal` on PATH.
        process = subprocess.run(["modal", *deploy_args], **run_options)

    stdout_text = process.stdout or ""
    stderr_text = process.stderr or ""
    if process.returncode != 0:
        combined_output = f"{stdout_text}\n{stderr_text}".strip()
        if "No module named modal" in combined_output:
            raise RuntimeError("Modal is not installed in this Python environment. Install it with: py -m pip install modal")
        raise RuntimeError(f"Modal deploy failed:\n{combined_output}")
    if stdout_text:
        # Only the tail is shown to keep console output bounded.
        print(stdout_text[-4000:])
def _build_modal_config_from_args(args):
    """Build the config dict sent to the remote job from ``args``.

    Every attribute in ``args.__dict__`` is copied through ``_to_json_safe``
    except the Modal credential and control flags listed below. The
    ``replay_path`` entry is always set to ``None`` in the result.
    """
    skipped = frozenset(
        (
            "modal_token_id",
            "modal_token_secret",
            "modal_api_key",
            "modal_cli_credentials_provided",
            "rebuild_modal",
            "delete_modal_embeddings",
        )
    )
    config = {
        name: _to_json_safe(raw)
        for name, raw in args.__dict__.items()
        if name not in skipped
    }
    config["replay_path"] = None
    return config
def _save_modal_artifacts(result_payload, output_root, job_id):
    """Write a fetched results payload under ``<output_root>/<job_id>``.

    * ``result_payload["files"]`` maps relative paths to text, written as UTF-8.
    * ``result_payload["images"]`` maps relative paths to dicts; entries with
      a ``"data"`` key are base64-decoded and written as bytes, others are
      skipped.
    * The whole payload is also dumped to ``modal_fetch_summary.json``.

    The relative paths come from the remote payload, so each one is checked
    after resolution: an entry that would land outside the job directory
    (``..`` segments or an absolute path) is skipped with a printed message
    instead of being written.

    Args:
        result_payload: Dict returned by the remote results function.
        output_root: Directory under which the job directory is created.
        job_id: Name of the job sub-directory.

    Returns:
        str: Path of the job directory.
    """
    job_dir = Path(output_root) / job_id
    job_dir.mkdir(parents=True, exist_ok=True)
    # Resolve once so symlinked roots compare correctly.
    resolved_job_dir = job_dir.resolve()

    def _prepare_target(rel_path):
        """Return the local path for ``rel_path``, or ``None`` if it escapes ``job_dir``."""
        local_path = job_dir / Path(rel_path)
        try:
            local_path.resolve().relative_to(resolved_job_dir)
        except ValueError:
            print(f"Skipping artifact with path outside the job directory: {rel_path}")
            return None
        local_path.parent.mkdir(parents=True, exist_ok=True)
        return local_path

    files_payload = result_payload.get("files") or {}
    for rel_path, text in files_payload.items():
        local_path = _prepare_target(rel_path)
        if local_path is None:
            continue
        with open(local_path, "w", encoding="utf-8") as file:
            file.write(text)

    images_payload = result_payload.get("images") or {}
    for rel_path, image_info in images_payload.items():
        if "data" not in image_info:
            continue
        local_path = _prepare_target(rel_path)
        if local_path is None:
            continue
        with open(local_path, "wb") as file:
            file.write(base64.b64decode(image_info["data"]))

    summary_path = job_dir / "modal_fetch_summary.json"
    with open(summary_path, "w", encoding="utf-8") as file:
        json.dump(result_payload, file, indent=2)
    return str(job_dir)
def _coerce_modal_terminal_payload(remote_result):
    """Turn a function-call result into a dict that always has a ``"status"`` key.

    A non-dict result yields ``{"status": "SUCCESS"}``. A dict result is
    shallow-copied; if it has no ``"status"``, one is added: ``"FAILED"``
    when a ``"success"`` key is present and falsy, ``"SUCCESS"`` otherwise.
    """
    if not isinstance(remote_result, dict):
        return {"status": "SUCCESS"}
    payload = dict(remote_result)
    if "status" in payload:
        return payload
    reported_failure = "success" in payload and not payload["success"]
    payload["status"] = "FAILED" if reported_failure else "SUCCESS"
    return payload
def _run_on_modal_cli(args):
    """Submit a job to the ``protify-backend`` Modal app and follow it to completion.

    Steps, in order:
      1. Import the Modal SDK.
      2. Read optional tuning values from ``args`` (GPU type, timeouts, poll
         interval, log chunk size, stale-heartbeat limit, artifacts dir),
         falling back to defaults when the attribute is missing or ``None``.
      3. Redeploy the backend when ``args.rebuild_modal`` is set.
      4. Optionally clear the remote embedding cache
         (``args.delete_modal_embeddings``). If that call fails, deploy the
         backend and retry once.
      5. Return 0 early when no dataset and no ``proteingym`` run is requested.
      6. Submit the job (deploy and retry once on failure), then poll status
         and stream log deltas until a terminal state, a resolved function
         call, a polling timeout, or a stale heartbeat with no function call
         handle.
      7. Fetch trailing logs and results, and save artifacts locally.

    Args:
        args: Parsed CLI namespace. Must expose ``rebuild_modal``,
            ``delete_modal_embeddings``, ``data_names``, ``data_dirs``,
            ``proteingym``, ``hf_token``, ``wandb_api_key`` and
            ``synthyra_api_key``; the ``modal_*`` tuning attributes are optional.

    Returns:
        int: 0 when the final status is ``"SUCCESS"`` (or nothing had to be
        run), 1 otherwise.

    Raises:
        RuntimeError: The Modal SDK cannot be imported.
        AssertionError: A submit or status response is not a dict, or the
            submit response has no ``job_id``.
    """
    try:
        import modal
    except Exception as error:
        raise RuntimeError("Modal SDK is required for CLI remote execution. Install with: py -m pip install modal") from error

    app_name = "protify-backend"
    options = args.__dict__

    def _option(name, default):
        """Return ``args.<name>`` unless it is missing or ``None``."""
        value = options.get(name)
        return default if value is None else value

    gpu_type = _option("modal_gpu_type", "A10")
    timeout_seconds = _option("modal_timeout_seconds", 86400)
    poll_interval_seconds = _option("modal_poll_interval_seconds", 5)
    log_tail_chars = _option("modal_log_tail_chars", 5000)
    max_stale_heartbeat_seconds = _option("modal_max_stale_heartbeat_seconds", 600)
    artifacts_root = _option("modal_artifacts_dir", "modal_artifacts")

    if args.rebuild_modal:
        print("Rebuilding Modal backend due to --rebuild_modal ...")
        _deploy_modal_backend(args)

    config = _build_modal_config_from_args(args)

    def _lookup_functions():
        """Look up every remote function handle used below, in a fixed order."""
        names = (
            "submit_protify_job",
            "get_job_status",
            "get_job_log_delta",
            "get_results",
            "delete_modal_embeddings",
        )
        return tuple(modal.Function.from_name(app_name, name) for name in names)

    submit_fn, status_fn, log_delta_fn, results_fn, delete_embeddings_fn = _lookup_functions()

    if args.delete_modal_embeddings:
        print("Deleting Modal embedding cache due to --delete_modal_embeddings ...")
        try:
            delete_embeddings_payload = delete_embeddings_fn.remote()
        except Exception as error:
            print("Modal embedding delete failed before app/function lookup succeeded; attempting deploy then retry...")
            # Show the triggering error so an unrelated failure is not hidden.
            print(f"Original error: {error}")
            _deploy_modal_backend(args)
            submit_fn, status_fn, log_delta_fn, results_fn, delete_embeddings_fn = _lookup_functions()
            delete_embeddings_payload = delete_embeddings_fn.remote()
        if isinstance(delete_embeddings_payload, dict) and "message" in delete_embeddings_payload:
            print(delete_embeddings_payload["message"])

    has_dataset_run = len(args.data_names) > 0 or len(args.data_dirs) > 0
    if not has_dataset_run and not args.proteingym:
        return 0

    submit_kwargs = {
        "config": config,
        "gpu_type": gpu_type,
        "hf_token": args.hf_token,
        "wandb_api_key": args.wandb_api_key,
        "synthyra_api_key": args.synthyra_api_key,
        "timeout_seconds": timeout_seconds,
    }
    try:
        submit_result = submit_fn.remote(**submit_kwargs)
    except Exception as error:
        print("Modal submit failed before app/function lookup succeeded; attempting deploy then retry...")
        print(f"Original error: {error}")
        _deploy_modal_backend(args)
        submit_fn, status_fn, log_delta_fn, results_fn, delete_embeddings_fn = _lookup_functions()
        submit_result = submit_fn.remote(**submit_kwargs)

    # Explicit raises (not `assert`) so validation survives `python -O`.
    if not isinstance(submit_result, dict):
        raise AssertionError("Modal submit response is not a dictionary.")
    if "job_id" not in submit_result:
        raise AssertionError("Modal submit response missing job_id.")
    job_id = submit_result["job_id"]
    function_call_id = submit_result.get("function_call_id")
    print(f"Modal job submitted: {job_id}")
    if function_call_id is not None:
        print(f"Modal function call id: {function_call_id}")

    terminal_states = {"SUCCESS", "FAILED", "TERMINATED", "TIMEOUT"}
    final_status_payload = None
    poll_start_time = time.time()
    # Poll slightly longer than the remote timeout before giving up locally.
    max_poll_seconds = int(timeout_seconds) + 900
    status_print_interval_seconds = 15
    last_status_print_time = 0.0
    last_status_line = ""
    missing_status_count = 0
    log_offset = 0
    function_call = None
    if function_call_id is not None:
        function_call = modal.FunctionCall.from_id(function_call_id)

    def _emit_remote_logs(max_chars):
        """Fetch the next log chunk from ``log_offset`` and echo it to stdout."""
        nonlocal log_offset
        delta_payload = log_delta_fn.remote(job_id=job_id, offset=log_offset, max_chars=max_chars)
        if not isinstance(delta_payload, dict):
            return
        next_offset = delta_payload.get("next_offset")
        if isinstance(next_offset, int):
            log_offset = next_offset
        chunk = delta_payload.get("chunk")
        if chunk:
            sys.stdout.write(chunk)
            sys.stdout.flush()

    while True:
        _emit_remote_logs(log_tail_chars)
        status_payload = status_fn.remote(job_id=job_id)
        if not isinstance(status_payload, dict):
            raise AssertionError("Modal status response is not a dictionary.")

        status_known = bool(status_payload.get("success"))
        heartbeat_age = status_payload.get("heartbeat_age_seconds") if status_known else None
        if status_known:
            missing_status_count = 0
            status_value = status_payload.get("status", "UNKNOWN")
            phase_value = status_payload.get("phase", "N/A")
            heartbeat_text = "N/A" if heartbeat_age is None else f"{heartbeat_age:.1f}s"
            status_line = f"[Modal] status={status_value} phase={phase_value} heartbeat_age={heartbeat_text}"
            if status_value in terminal_states:
                final_status_payload = dict(status_payload)
                break
        else:
            missing_status_count += 1
            status_line = "[Modal] state=queued_or_initializing"
            # Surface the backend's error detail only on every sixth miss.
            if missing_status_count % 6 == 0 and status_payload.get("error"):
                status_line = f"[Modal] state=queued_or_initializing detail={status_payload['error']}"

        now = time.time()
        # Print on change, or as a periodic keep-alive line.
        if status_line != last_status_line or (now - last_status_print_time) >= status_print_interval_seconds:
            print(status_line)
            last_status_line = status_line
            last_status_print_time = now

        if function_call is not None:
            try:
                # Non-blocking probe; TimeoutError means "not finished yet".
                remote_result = function_call.get(timeout=0)
            except TimeoutError:
                pass
            except Exception as error:
                final_status_payload = {"status": "FAILED", "error": f"Function call failed: {error}"}
                break
            else:
                final_status_payload = _coerce_modal_terminal_payload(remote_result)
                if "phase" not in final_status_payload and "phase" in status_payload:
                    final_status_payload["phase"] = status_payload["phase"]
                break

        if now - poll_start_time > max_poll_seconds:
            final_status_payload = {
                "status": "TIMEOUT",
                "phase": "poll_timeout",
                "error": f"Polling exceeded timeout window ({max_poll_seconds} seconds).",
            }
            break

        # Without a function-call handle the heartbeat is the only liveness signal.
        heartbeat_is_stale = heartbeat_age is not None and heartbeat_age > max_stale_heartbeat_seconds
        if status_known and heartbeat_is_stale and function_call is None:
            final_status_payload = {
                "status": "FAILED",
                "phase": "stale_heartbeat",
                "error": f"Heartbeat stale for {heartbeat_age:.1f}s with no function_call_id available.",
            }
            break

        time.sleep(max(1, int(poll_interval_seconds)))

    # Trailing logs are best-effort: a failure here must not block the
    # artifact download or hide the job's exit status.
    try:
        _emit_remote_logs(log_tail_chars * 8)
    except Exception as error:
        print(f"Could not fetch trailing Modal logs: {error}")

    try:
        results_payload = results_fn.remote(job_id=job_id)
    except Exception as error:
        results_payload = {"success": False, "error": str(error)}
    if isinstance(results_payload, dict) and results_payload.get("success"):
        artifacts_dir = _save_modal_artifacts(results_payload, artifacts_root, job_id)
        print(f"Modal artifacts saved to {artifacts_dir}")
    else:
        # Report why nothing was saved instead of dropping the reason silently.
        detail = results_payload.get("error") if isinstance(results_payload, dict) else None
        if detail:
            print(f"Modal artifacts were not fetched: {detail}")
        else:
            print("Modal artifacts were not fetched.")

    if final_status_payload is None:
        # Defensive: every loop exit above assigns a payload.
        final_status_payload = {"status": "FAILED", "error": "No terminal status was resolved."}
    final_status = final_status_payload.get("status", "FAILED")
    if final_status != "SUCCESS":
        if final_status_payload.get("error"):
            print(f"Modal job failed: {final_status_payload['error']}")
        else:
            print(f"Modal job finished with status {final_status}")
        return 1
    return 0