| from __future__ import annotations |
| from copy import deepcopy |
| from typing import Any, Dict |
|
|
| import hydra |
| from langchain.tools import BaseTool |
|
|
| from flows.base_flows import AtomicFlow |
|
|
|
|
class LCToolFlow(AtomicFlow):
    """Atomic flow that delegates its work to a LangChain tool (``BaseTool``).

    The tool (the "backend") is built from the ``backend`` entry of the flow
    config via ``hydra.utils.instantiate``. Each call to :meth:`run` forwards
    the whole input dictionary to the tool and returns the tool's result as
    ``{"observation": <result>}``.
    """

    # Config keys this flow requires; presumably checked by the base flow
    # machinery (AtomicFlow) — confirm there.
    REQUIRED_KEYS_CONFIG = ["backend"]

    # Caching is switched off for this flow; the flag itself is consumed by
    # the base classes, not in this file.
    SUPPORTS_CACHING: bool = False

    # The LangChain tool that performs the actual work; assigned in __init__.
    backend: BaseTool

    def __init__(self, backend: BaseTool, **kwargs) -> None:
        """Store the tool instance; everything else goes to the base flow.

        Args:
            backend: The LangChain tool that :meth:`run` invokes.
            **kwargs: Forwarded unchanged to ``AtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        self.backend = backend

    @classmethod
    def _set_up_backend(cls, config: Dict[str, Any]) -> BaseTool:
        """Instantiate the LangChain tool described by ``config``.

        A ``_target_`` starting with ``"."`` is treated as a Python-style
        relative reference, resolved against the package that contains
        ``cls``. For example, ``".tools.MyTool"`` for a class defined in
        ``pkg.flows.my_flow`` becomes ``"pkg.flows.tools.MyTool"``. Each
        additional leading dot climbs one package level, as in a relative
        import. Absolute targets are used as-is.

        Args:
            config: Hydra-style config with a ``_target_`` key plus the
                tool's constructor arguments. It is NOT modified.

        Returns:
            The instantiated tool.

        Raises:
            ImportError: If a relative ``_target_`` cannot be resolved
                (``cls`` lives in a top-level module, or the target climbs
                above the top-level package). Python < 3.9 raises
                ValueError instead.
        """
        from importlib.util import resolve_name

        # Work on a private copy. Rewriting ``_target_`` in place would leak
        # the resolved path back into the caller's config (e.g. the dict
        # passed to ``instantiate_from_config``). A subclass in another
        # package that reuses that config would then no longer resolve the
        # target relative to itself.
        config = deepcopy(config)
        target = config["_target_"]
        if target.startswith("."):
            # Design assumption carried over from the original code: a flow
            # class and its config live side by side, so relative targets
            # are anchored at the package containing ``cls``.
            cls_parent_module = cls.__module__.rpartition(".")[0]
            config["_target_"] = resolve_name(target, cls_parent_module)

        # ``_convert_="partial"`` selects how Hydra converts OmegaConf
        # containers before handing them to the tool's constructor.
        return hydra.utils.instantiate(config, _convert_="partial")

    @classmethod
    def instantiate_from_config(cls, config: Dict[str, Any]) -> LCToolFlow:
        """Build the flow and its LangChain backend from a config dict.

        Args:
            config: Flow config. It must contain a ``backend`` entry
                describing the tool (see :meth:`_set_up_backend`).

        Returns:
            A new instance of ``cls``.
        """
        # Snapshot the config first. This copy is handed to the constructor
        # as ``flow_config``.
        flow_config = deepcopy(config)

        # ~~~ Set up the LangChain backend ~~~
        backend = cls._set_up_backend(config["backend"])

        # ~~~ Instantiate the flow ~~~
        return cls(flow_config=flow_config, backend=backend)

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke the wrapped tool on ``input_data``.

        Args:
            input_data: Passed unchanged to the tool as ``tool_input``.

        Returns:
            ``{"observation": <value returned by the tool's run method>}``.
        """
        observation = self.backend.run(tool_input=input_data)
        return {"observation": observation}
|
|
|
|