# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Controller composition for recursive backends.

This keeps server-side recursion assembly outside `REPLEnvironment`:
- backend selection based on max_depth
- limits configuration
- uniform callable interface for the environment
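
The depth-based selection above can be sketched in isolation. This is a
minimal illustration, not the module's implementation: `SketchController`,
`make_sketch_controller`, and `echo` are hypothetical stand-ins, and the real
backends are replaced by a plain callable.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class SketchController:
    # Hypothetical stand-in for RecursiveController, reduced to two fields.
    llm_query_fn: Callable[[str], str]
    rlm_query_fn: Optional[Callable[[str], str]]


def make_sketch_controller(
    chat_fn: Callable[[str], str], max_depth: int
) -> SketchController:
    # Same rule as this module: recursive entry points exist only when
    # max_depth > 1; otherwise they are left as None.
    if max_depth > 1:
        return SketchController(llm_query_fn=chat_fn, rlm_query_fn=chat_fn)
    return SketchController(llm_query_fn=chat_fn, rlm_query_fn=None)


def echo(prompt: str) -> str:
    # Hypothetical chat function used only for this sketch.
    return "echo: " + prompt


flat = make_sketch_controller(echo, max_depth=1)
deep = make_sketch_controller(echo, max_depth=2)
```

Callers are expected to check `rlm_query_fn` for None before using it, since
a depth of 1 or less disables recursion.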
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

from .prompts import RLM_SYSTEM_PROMPT
from .recursive_backends import BackendLimits, DirectLMBackend, LocalChildRLMBackend


ChatFn = Callable[..., str]


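@dataclass
class RecursiveController:
    """Uniform callable interface that the environment uses to reach a backend.

    `rlm_query_fn` and `rlm_batch_fn` are None when recursion is disabled.
    """

    llm_query_fn: Callable[[str, str | None], str]
    llm_batch_fn: Callable[[list[str], str | None], list[str]]
    rlm_query_fn: Callable[[str, str | None], str] | None
    rlm_batch_fn: Callable[[list[str], str | None], list[str]] | None
    backend: object

    def close(self) -> None:
        """Release controller resources. Currently a no-op."""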


def create_server_recursive_controller(
    chat_fn: ChatFn,
    *,
    max_depth: int,
    max_iterations: int,
    system_prompt: str = RLM_SYSTEM_PROMPT,
    max_batch_workers: int = 8,
    max_children_total: int | None = None,
    max_children_per_batch: int | None = None,
    result_truncation_limit: int | None = None,
    per_child_timeout_s: float | None = None,
    env_max_iterations_multiplier: int = 5,
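) -> RecursiveController:
    """Build the controller for server-side recursion.

    With `max_depth > 1` the controller wraps a `LocalChildRLMBackend` and
    exposes recursive query functions. Otherwise it wraps a `DirectLMBackend`
    and leaves `rlm_query_fn` and `rlm_batch_fn` as None.
    """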
    limits = BackendLimits(
        max_depth=max_depth,
        max_batch_workers=max_batch_workers,
        max_children_total=max_children_total,
        max_children_per_batch=max_children_per_batch,
        result_truncation_limit=result_truncation_limit,
        per_child_timeout_s=per_child_timeout_s,
    )
    if max_depth > 1:
        # Recursion enabled: the runner is imported only on this path.
        from .runner import LocalRLMRunner

        backend = LocalChildRLMBackend(
            chat_fn,
            runner_factory=LocalRLMRunner,
            system_prompt=system_prompt,
            max_iterations=max_iterations,
            env_max_iterations_multiplier=env_max_iterations_multiplier,
            depth=0,
            limits=limits,
        )
        return RecursiveController(
            llm_query_fn=backend.query,
            llm_batch_fn=backend.query_batched,
            rlm_query_fn=backend.recursive_query,
            rlm_batch_fn=backend.recursive_query_batched,
            backend=backend,
        )

    # max_depth <= 1: direct LM calls only, no recursive entry points.
    backend = DirectLMBackend(chat_fn, depth=0, limits=limits)
    return RecursiveController(
        llm_query_fn=backend.query,
        llm_batch_fn=backend.query_batched,
        rlm_query_fn=None,
        rlm_batch_fn=None,
        backend=backend,
    )