Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Local recursive RLM runner for repl_env. | |
| This keeps the iterative prompting/orchestration layer outside the environment, | |
| following the same separation used by the official RLM implementation and DSPy: | |
| - `REPLEnvironment` executes code and exposes tools | |
| - `LocalRLMRunner` owns prompting, message history, and recursive child runs | |
| """ | |
| from __future__ import annotations | |
| import re | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Callable | |
| from .local import LocalREPLEnv | |
| from .prompts import ( | |
| build_rlm_system_prompt, | |
| build_user_prompt, | |
| extract_code_blocks, | |
| format_observations, | |
| QueryMetadata, | |
| RLM_SYSTEM_PROMPT, | |
| ) | |
| from .recursive_backends import BackendLimits, LocalChildRLMBackend, RecursiveBackend | |
| ChatFn = Callable[..., str] | |
| class RLMRunResult: | |
| final_answer: str | None | |
| messages: list[dict[str, str]] | |
| iterations: int | |
| depth: int | |
| child_traces: list[object] | |
| class LocalRLMRunner: | |
| """Local recursive RLM orchestrator built on top of LocalREPLEnv.""" | |
| def __init__( | |
| self, | |
| llm_chat_fn: ChatFn, | |
| *, | |
| system_prompt: str = RLM_SYSTEM_PROMPT, | |
| max_iterations: int = 30, | |
| max_depth: int = 2, | |
| depth: int = 0, | |
| env_max_iterations_multiplier: int = 5, | |
| max_batch_workers: int = 8, | |
| backend_factory: Callable[..., RecursiveBackend] | None = None, | |
| max_children_total: int | None = None, | |
| max_children_per_batch: int | None = None, | |
| result_truncation_limit: int | None = None, | |
| per_child_timeout_s: float | None = None, | |
| on_subcall_start: Callable[[int, str, str], None] | None = None, | |
| on_subcall_complete: Callable[[int, str, float, str | None], None] | |
| | None = None, | |
| verbose: bool = False, | |
| ) -> None: | |
| self.llm_chat_fn = llm_chat_fn | |
| self.system_prompt = system_prompt | |
| self.max_iterations = max_iterations | |
| self.max_depth = max_depth | |
| self.depth = depth | |
| self.env_max_iterations_multiplier = env_max_iterations_multiplier | |
| self.max_batch_workers = max_batch_workers | |
| self.backend_factory = backend_factory or self._default_backend_factory | |
| self.max_children_total = max_children_total | |
| self.max_children_per_batch = max_children_per_batch | |
| self.result_truncation_limit = result_truncation_limit | |
| self.per_child_timeout_s = per_child_timeout_s | |
| self.on_subcall_start = on_subcall_start | |
| self.on_subcall_complete = on_subcall_complete | |
| self.verbose = verbose | |
| def _default_backend_factory( | |
| self, llm_chat_fn: ChatFn, **kwargs | |
| ) -> RecursiveBackend: | |
| limits = BackendLimits( | |
| max_depth=self.max_depth, | |
| max_batch_workers=self.max_batch_workers, | |
| max_children_total=self.max_children_total, | |
| max_children_per_batch=self.max_children_per_batch, | |
| result_truncation_limit=self.result_truncation_limit, | |
| per_child_timeout_s=self.per_child_timeout_s, | |
| ) | |
| return LocalChildRLMBackend( | |
| llm_chat_fn, | |
| runner_factory=LocalRLMRunner, | |
| system_prompt=kwargs["system_prompt"], | |
| max_iterations=kwargs["max_iterations"], | |
| env_max_iterations_multiplier=kwargs["env_max_iterations_multiplier"], | |
| depth=kwargs["depth"], | |
| limits=limits, | |
| on_subcall_start=self.on_subcall_start, | |
| on_subcall_complete=self.on_subcall_complete, | |
| ) | |
| def run( | |
| self, | |
| context: str, | |
| task_prompt: str, | |
| *, | |
| model: str | None = None, | |
| timeout_s: float | None = None, | |
| ) -> RLMRunResult: | |
| backend = self.backend_factory( | |
| self.llm_chat_fn, | |
| system_prompt=self.system_prompt, | |
| max_iterations=self.max_iterations, | |
| max_depth=self.max_depth, | |
| depth=self.depth, | |
| env_max_iterations_multiplier=self.env_max_iterations_multiplier, | |
| ) | |
| with LocalREPLEnv( | |
| llm_query_fn=backend.query, | |
| llm_batch_fn=backend.query_batched, | |
| subcall_fn=backend.recursive_query, | |
| subcall_batch_fn=backend.recursive_query_batched, | |
| ) as env: | |
| result = env.reset( | |
| context=context, | |
| task_prompt=task_prompt, | |
| max_iterations=self.max_iterations * self.env_max_iterations_multiplier, | |
| llm_model=model, | |
| ) | |
| obs = result.observation | |
| query_metadata = QueryMetadata( | |
| context_lengths=[obs.context_length], | |
| context_total_length=obs.context_length, | |
| context_type="str", | |
| ) | |
| messages = build_rlm_system_prompt(self.system_prompt, query_metadata) | |
| messages.append(build_user_prompt(root_prompt=task_prompt, iteration=0)) | |
| run_start = time.perf_counter() | |
| for iteration in range(1, self.max_iterations + 1): | |
| # Cooperative timeout check (matches official RLM pattern) | |
| if timeout_s is not None: | |
| elapsed = time.perf_counter() - run_start | |
| if elapsed >= timeout_s: | |
| return RLMRunResult( | |
| final_answer=f"Error: child timeout after {elapsed:.3f}s", | |
| messages=messages, | |
| iterations=iteration - 1, | |
| depth=self.depth, | |
| child_traces=list(getattr(backend, "child_traces", [])), | |
| ) | |
| response = self._chat(messages, model) | |
| code_blocks = extract_code_blocks(response) | |
| code_block_observations = [] | |
| if self.verbose: | |
| print( | |
| f"[depth={self.depth}] iteration={iteration} code_blocks={len(code_blocks)}" | |
| ) | |
| if not code_blocks: | |
| messages.append({"role": "assistant", "content": response}) | |
| messages.append( | |
| { | |
| "role": "user", | |
| "content": ( | |
| "Please continue by writing Python code in ```repl``` blocks, " | |
| "or submit the final answer with FINAL(...) / FINAL_VAR(...)." | |
| ), | |
| } | |
| ) | |
| continue | |
| for code in code_blocks: | |
| result = env.execute(code) | |
| code_block_observations.append(result.observation) | |
| # Check for FINAL after all blocks executed (matches official RLM). | |
| # The model expects all blocks to run — it often writes exploration | |
| # code first and FINAL last in the same response. | |
| if any(obs.done for obs in code_block_observations): | |
| return RLMRunResult( | |
| final_answer=env.state().final_answer, | |
| messages=messages | |
| + [{"role": "assistant", "content": response}], | |
| iterations=iteration, | |
| depth=self.depth, | |
| child_traces=list(getattr(backend, "child_traces", [])), | |
| ) | |
| observation_text = format_observations( | |
| code_block_observations, code_blocks=code_blocks | |
| ) | |
| next_prompt = build_user_prompt( | |
| root_prompt=task_prompt, | |
| iteration=iteration, | |
| ) | |
| messages.append({"role": "assistant", "content": response}) | |
| messages.append( | |
| { | |
| "role": "user", | |
| "content": observation_text + "\n\n" + next_prompt["content"], | |
| } | |
| ) | |
| # Max iterations exhausted — give the model one final chance to answer | |
| final_answer = env.state().final_answer | |
| if final_answer is None: | |
| final_answer = self._default_answer(messages, model) | |
| return RLMRunResult( | |
| final_answer=final_answer, | |
| messages=messages, | |
| iterations=self.max_iterations, | |
| depth=self.depth, | |
| child_traces=list(getattr(backend, "child_traces", [])), | |
| ) | |
| def _default_answer( | |
| self, messages: list[dict[str, str]], model: str | None = None | |
| ) -> str | None: | |
| """Make one final LLM call asking for an answer when iterations are exhausted.""" | |
| final_prompt = messages + [ | |
| { | |
| "role": "user", | |
| "content": ( | |
| "You have run out of REPL iterations. Based on all your work above, " | |
| "provide your best final answer now. Use FINAL(your answer) to submit it. " | |
| "If you stored the answer in a variable, use FINAL_VAR(variable_name) instead. " | |
| "Do not write any more code — just provide the final answer." | |
| ), | |
| } | |
| ] | |
| try: | |
| response = self._chat(final_prompt, model) | |
| # Try to extract FINAL(...) from the response | |
| match = re.search(r"FINAL\((.*?)\)", response, re.DOTALL) | |
| if match: | |
| return match.group(1).strip() | |
| # If no FINAL pattern, return the raw response as best-effort | |
| return response.strip() if response.strip() else None | |
| except Exception: | |
| return None | |
| def _chat(self, messages: list[dict[str, str]], model: str | None = None) -> str: | |
| try: | |
| return self.llm_chat_fn(messages, model) | |
| except TypeError: | |
| return self.llm_chat_fn(messages) | |