# repl_env-0.2.2 / recursive_backends.py
# Uploaded with huggingface_hub by burtenshaw (HF Staff), commit 7031bab (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Recursive backend abstractions for repl_env.
This module keeps direct LM calls and recursive child spawning out of the
runner and environment. The runner owns the iterative loop; the backend owns
query/query_batched/child-recursion behavior.
"""
from __future__ import annotations

import inspect
import threading
import time
from concurrent.futures import as_completed, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Callable, Protocol
# Signature of an LM chat function: called with a list of message dicts
# ({"role": ..., "content": ...}) and, when the callable supports it, a
# positional model name; returns the assistant's reply text.
ChatFn = Callable[..., str]
class RecursiveBackend(Protocol):
    """Structural interface for backends driven by the runner.

    ``query``/``query_batched`` perform direct LM calls; the
    ``recursive_*`` variants may additionally spawn child runners one
    depth level down.  Implementations without recursion simply delegate
    the recursive variants to the direct ones.
    """

    # Maximum recursion depth permitted for this backend's tree.
    max_depth: int
    # Depth of this backend within the tree (0 = root).
    depth: int
    # Trace records of child runs spawned by this backend.
    child_traces: list["ChildTrace"]
    def query(self, prompt: str, model: str | None = None) -> str: ...
    def query_batched(
        self, prompts: list[str], model: str | None = None
    ) -> list[str]: ...
    def recursive_query(self, prompt: str, model: str | None = None) -> str: ...
    def recursive_query_batched(
        self, prompts: list[str], model: str | None = None
    ) -> list[str]: ...
@dataclass
class BackendLimits:
    """Resource budgets for one backend tree.

    A single instance is shared by a parent backend and every child it
    spawns, so the private child counter below is enforced tree-globally.
    """

    # Maximum recursion depth (1 means child spawning is disabled).
    max_depth: int = 1
    # Upper bound on worker threads for batched queries.
    max_batch_workers: int = 8
    # Cap on children spawned across the whole tree (None = unlimited).
    max_children_total: int | None = None
    # Cap on children spawned by one batched call (None = unlimited).
    max_children_per_batch: int | None = None
    # Truncate LM results to this many characters (None = no truncation).
    result_truncation_limit: int | None = None
    # Cooperative timeout: checked between iterations, not during LLM calls.
    # A slow LLM call within an iteration will not be interrupted — the timeout
    # fires at the next iteration boundary. For mid-call cancellation, use
    # process-based isolation instead.
    per_child_timeout_s: float | None = None
    # Tree-global child counter shared across all recursion depths
    _children_spawned: int = field(default=0, init=False, repr=False)
    # Guards _children_spawned against concurrent (batched) spawns.
    _children_lock: threading.Lock = field(
        default_factory=threading.Lock, init=False, repr=False
    )
@dataclass
class ChildTrace:
    """Record of one attempted child spawn, kept for observability."""

    # Depth the child ran at (parent depth + 1).
    depth: int
    # Wall-clock duration of the child run, in seconds.
    duration_s: float
    # First 80 characters of the prompt sent to the child.
    prompt_preview: str
    # First 80 characters of the child's answer; None when empty or failed.
    result_preview: str | None
    # str() of the exception raised by the child, else None.
    error: str | None
class DirectLMBackend:
    """Direct LM backend with no child recursion beyond fallback to itself.

    Implements the ``RecursiveBackend`` protocol: ``query`` performs one
    chat call, ``query_batched`` fans prompts out over a thread pool, and
    the ``recursive_*`` variants simply delegate to the direct ones.
    """

    def __init__(
        self,
        llm_chat_fn: ChatFn,
        *,
        depth: int = 0,
        limits: BackendLimits | None = None,
    ) -> None:
        self.llm_chat_fn = llm_chat_fn
        self.depth = depth
        self.limits = limits or BackendLimits()
        self.max_depth = self.limits.max_depth
        self.child_traces: list[ChildTrace] = []
        # Decide once whether llm_chat_fn accepts a second (model) argument.
        # The old approach probed with try/except TypeError on every call,
        # which silently *invoked the LM twice* whenever a TypeError was
        # raised inside llm_chat_fn rather than by the call arity.
        # None means introspection failed; query() then falls back to the
        # runtime probe.
        self._accepts_model = self._detect_model_param(llm_chat_fn)

    @staticmethod
    def _detect_model_param(fn: ChatFn) -> bool | None:
        """Return whether *fn* can take a second positional argument.

        True/False when the signature is introspectable, None otherwise
        (e.g. some C-implemented callables).
        """
        try:
            params = list(inspect.signature(fn).parameters.values())
        except (TypeError, ValueError):
            return None
        if any(p.kind is p.VAR_POSITIONAL for p in params):
            return True
        positional = [
            p
            for p in params
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        return len(positional) >= 2

    def query(self, prompt: str, model: str | None = None) -> str:
        """Send *prompt* as a single user message and return the reply.

        The reply is truncated to ``limits.result_truncation_limit``.
        """
        messages = [{"role": "user", "content": prompt}]
        if self._accepts_model is None:
            # Signature introspection failed: probe at call time, accepting
            # the (rare) risk of a duplicate call on an internal TypeError.
            try:
                result = self.llm_chat_fn(messages, model)
            except TypeError:
                result = self.llm_chat_fn(messages)
        elif self._accepts_model:
            result = self.llm_chat_fn(messages, model)
        else:
            result = self.llm_chat_fn(messages)
        return self._truncate(result)

    def query_batched(self, prompts: list[str], model: str | None = None) -> list[str]:
        """Run ``query`` for each prompt concurrently.

        Returns a list aligned with *prompts*; a failing query yields an
        ``"Error: ..."`` string instead of raising.
        """
        if not prompts:
            return []
        max_workers = min(len(prompts), self.limits.max_batch_workers)
        results: list[str] = [""] * len(prompts)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {
                executor.submit(self.query, prompt, model): idx
                for idx, prompt in enumerate(prompts)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as exc:
                    results[idx] = f"Error: {exc}"
        return results

    def recursive_query(self, prompt: str, model: str | None = None) -> str:
        """No recursion in this backend: identical to ``query``."""
        return self.query(prompt, model)

    def recursive_query_batched(
        self, prompts: list[str], model: str | None = None
    ) -> list[str]:
        """No recursion in this backend: identical to ``query_batched``."""
        return self.query_batched(prompts, model)

    def _truncate(self, result: str) -> str:
        """Clip *result* to the configured truncation limit, if any."""
        limit = self.limits.result_truncation_limit
        if limit is not None and len(result) > limit:
            return result[:limit]
        return result
class LocalChildRLMBackend(DirectLMBackend):
    """Recursive backend that spawns child LocalRLMRunner instances.

    ``recursive_query`` hands the prompt to a child runner one depth level
    down; once ``max_depth`` would be reached it degrades to a direct LM
    call.  All backends in the tree share one ``BackendLimits`` instance
    (see ``_child_backend_factory``), so ``max_children_total`` is
    enforced globally across every depth.
    """

    def __init__(
        self,
        llm_chat_fn: ChatFn,
        *,
        runner_factory: Callable[..., object],
        system_prompt: str,
        max_iterations: int,
        env_max_iterations_multiplier: int,
        depth: int = 0,
        limits: BackendLimits | None = None,
        on_subcall_start: Callable[[int, str, str], None] | None = None,
        on_subcall_complete: Callable[[int, str, float, str | None], None]
        | None = None,
    ) -> None:
        super().__init__(llm_chat_fn, depth=depth, limits=limits)
        self.runner_factory = runner_factory
        self.system_prompt = system_prompt
        self.max_iterations = max_iterations
        self.env_max_iterations_multiplier = env_max_iterations_multiplier
        # Optional observer hooks; any exception they raise is swallowed so
        # instrumentation can never abort a recursion step.
        self.on_subcall_start = on_subcall_start
        self.on_subcall_complete = on_subcall_complete

    def recursive_query(self, prompt: str, model: str | None = None) -> str:
        """Answer *prompt* via a child runner one level deeper.

        Returns the (truncated) child answer, a plain ``query`` result when
        the depth budget is exhausted, or an error string when the
        tree-global child budget is spent.  Every spawn attempt — success
        or failure — is appended to ``child_traces`` and reported through
        ``on_subcall_complete``; child exceptions are re-raised after
        tracing.
        """
        next_depth = self.depth + 1
        if next_depth >= self.max_depth:
            # Depth budget exhausted: degrade to a direct LM call.
            return self.query(prompt, model)
        # Reserve one slot of the shared, tree-global child budget.
        with self.limits._children_lock:
            if self.limits.max_children_total is not None:
                if self.limits._children_spawned >= self.limits.max_children_total:
                    return "Error: max_children_total exceeded"
            self.limits._children_spawned += 1
        start = time.perf_counter()
        error: str | None = None
        result_text = ""
        resolved_model = model or "default"
        if self.on_subcall_start is not None:
            try:
                self.on_subcall_start(next_depth, str(resolved_model), prompt[:80])
            except Exception:
                pass  # observer failures must not break the recursion
        try:
            child = self.runner_factory(
                self.llm_chat_fn,
                system_prompt=self.system_prompt,
                max_iterations=self.max_iterations,
                max_depth=self.max_depth,
                depth=next_depth,
                env_max_iterations_multiplier=self.env_max_iterations_multiplier,
                max_batch_workers=self.limits.max_batch_workers,
                backend_factory=self._child_backend_factory,
                on_subcall_start=self.on_subcall_start,
                on_subcall_complete=self.on_subcall_complete,
            )
            result = child.run(
                prompt, prompt, model=model, timeout_s=self.limits.per_child_timeout_s
            )
            result_text = self._truncate(result.final_answer or "")
            return result_text
        except Exception as exc:
            error = str(exc)
            raise
        finally:
            # Trace + completion hook run on success and failure alike.
            duration = time.perf_counter() - start
            self.child_traces.append(
                ChildTrace(
                    depth=next_depth,
                    duration_s=duration,
                    prompt_preview=prompt[:80],
                    result_preview=(result_text[:80] if result_text else None),
                    error=error,
                )
            )
            if self.on_subcall_complete is not None:
                try:
                    self.on_subcall_complete(
                        next_depth,
                        str(resolved_model),
                        duration,
                        error,
                    )
                except Exception:
                    pass

    def recursive_query_batched(
        self, prompts: list[str], model: str | None = None
    ) -> list[str]:
        """Run ``recursive_query`` for each prompt concurrently.

        The returned list is always aligned with *prompts*: entry ``i``
        answers ``prompts[i]``.  Prompts beyond ``max_children_per_batch``
        are not executed and receive an error placeholder.  (Previously
        the excess prompts were silently dropped and a *shorter* list was
        returned, so callers zipping prompts with results got misaligned
        pairs — inconsistent with the base class contract.)
        """
        if not prompts:
            return []
        batch_limit = self.limits.max_children_per_batch
        active = (
            len(prompts) if batch_limit is None else min(batch_limit, len(prompts))
        )
        results: list[str] = [
            "Error: max_children_per_batch exceeded"
        ] * len(prompts)
        if active <= 0:
            # Nothing may run; also avoids ThreadPoolExecutor(max_workers=0).
            return results
        max_workers = min(active, self.limits.max_batch_workers)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {
                executor.submit(self.recursive_query, prompts[idx], model): idx
                for idx in range(active)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as exc:
                    results[idx] = f"Error: {exc}"
        return results

    def _child_backend_factory(
        self, llm_chat_fn: ChatFn, **kwargs
    ) -> "LocalChildRLMBackend":
        """Build the backend for a child runner.

        Shares this tree's ``limits`` object and observer hooks so that
        child budgets and instrumentation stay tree-global.
        """
        return LocalChildRLMBackend(
            llm_chat_fn,
            runner_factory=self.runner_factory,
            system_prompt=kwargs["system_prompt"],
            max_iterations=kwargs["max_iterations"],
            env_max_iterations_multiplier=kwargs["env_max_iterations_multiplier"],
            depth=kwargs["depth"],
            limits=self.limits,
            on_subcall_start=self.on_subcall_start,
            on_subcall_complete=self.on_subcall_complete,
        )