| | from typing import Any, List, Optional, Sequence, Tuple |
| |
|
| | from langchain_core._api import deprecated |
| | from langchain_core.agents import AgentAction |
| | from langchain_core.callbacks import BaseCallbackManager |
| | from langchain_core.language_models import BaseLanguageModel |
| | from langchain_core.prompts import BasePromptTemplate |
| | from langchain_core.prompts.chat import ( |
| | ChatPromptTemplate, |
| | HumanMessagePromptTemplate, |
| | SystemMessagePromptTemplate, |
| | ) |
| | from langchain_core.tools import BaseTool |
| | from pydantic import Field |
| |
|
| | from langchain._api.deprecation import AGENT_DEPRECATION_WARNING |
| | from langchain.agents.agent import Agent, AgentOutputParser |
| | from langchain.agents.chat.output_parser import ChatOutputParser |
| | from langchain.agents.chat.prompt import ( |
| | FORMAT_INSTRUCTIONS, |
| | HUMAN_MESSAGE, |
| | SYSTEM_MESSAGE_PREFIX, |
| | SYSTEM_MESSAGE_SUFFIX, |
| | ) |
| | from langchain.agents.utils import validate_tools_single_input |
| | from langchain.chains.llm import LLMChain |
| |
|
| |
|
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
class ChatAgent(Agent):
    """Chat Agent.

    Agent whose prompt is a system message (prefix, tool list, format
    instructions, suffix) followed by a human message. Replies are parsed with
    ``ChatOutputParser`` unless another parser is supplied. The stop sequence is
    ``"Observation:"``, and tools are additionally checked with
    ``validate_tools_single_input``.

    Deprecated since 0.1.0; scheduled for removal in 1.0.
    """

    output_parser: AgentOutputParser = Field(default_factory=ChatOutputParser)
    """Output parser for the agent."""

    @property
    def observation_prefix(self) -> str:
        """Prefix used to label tool observations."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix used to label the model's reasoning ("Thought:")."""
        return "Thought:"

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Build the scratchpad text for the prompt.

        Delegates to the base implementation. When that returns a non-empty
        string, it is prefixed with a note telling the model that only its
        final answer is visible.

        Args:
            intermediate_steps: Actions taken so far, each paired with its
                string result.

        Returns:
            The prefixed scratchpad, or the base scratchpad unchanged when it
            is empty.

        Raises:
            ValueError: If the base implementation returns a non-string.
        """
        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
        # This agent embeds the scratchpad in a text prompt, so anything other
        # than a plain string cannot be used here.
        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if not agent_scratchpad:
            return agent_scratchpad
        return (
            "This was your previous work "
            "(but I haven't seen any of it! I only see what "
            f"you return as final answer):\n{agent_scratchpad}"
        )

    @classmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Return a fresh ``ChatOutputParser``; ``kwargs`` are ignored."""
        return ChatOutputParser()

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Run the base tool validation, then ``validate_tools_single_input``."""
        super()._validate_tools(tools)
        validate_tools_single_input(class_name=cls.__name__, tools=tools)

    @property
    def _stop(self) -> List[str]:
        """Stop sequences for this agent."""
        return ["Observation:"]

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        system_message_prefix: str = SYSTEM_MESSAGE_PREFIX,
        system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
        human_message: str = HUMAN_MESSAGE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> BasePromptTemplate:
        """Create a prompt from a list of tools.

        The system message is the blank-line-separated concatenation of
        ``system_message_prefix``, one ``"name: description"`` line per tool,
        ``format_instructions`` (with ``{tool_names}`` filled in), and
        ``system_message_suffix``. The human message is built from
        ``human_message``.

        Args:
            tools: The tools to describe in the system message.
            system_message_prefix: The system message prefix.
                Default is SYSTEM_MESSAGE_PREFIX.
            system_message_suffix: The system message suffix.
                Default is SYSTEM_MESSAGE_SUFFIX.
            human_message: The human message template. Default is HUMAN_MESSAGE.
            format_instructions: The format instructions; ``{tool_names}`` is
                replaced with the comma-separated tool names.
                Default is FORMAT_INSTRUCTIONS.
            input_variables: The prompt's input variables. Defaults to
                ``["input", "agent_scratchpad"]`` when None.

        Returns:
            A ``ChatPromptTemplate`` with a system and a human message.

        Raises:
            KeyError: If ``format_instructions`` contains a named placeholder
                other than ``{tool_names}`` (raised by ``str.format``).
        """
        tool_strings = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
        tool_names = ", ".join(tool.name for tool in tools)
        format_instructions = format_instructions.format(tool_names=tool_names)
        # NOTE(review): the joined text below is parsed as a template, so braces
        # inside tool descriptions are presumably treated as template
        # variables — confirm whether descriptions need escaping.
        template = "\n\n".join(
            [
                system_message_prefix,
                tool_strings,
                format_instructions,
                system_message_suffix,
            ]
        )
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            HumanMessagePromptTemplate.from_template(human_message),
        ]
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        system_message_prefix: str = SYSTEM_MESSAGE_PREFIX,
        system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
        human_message: str = HUMAN_MESSAGE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Validates the tools, builds the prompt with ``create_prompt``, wraps
        ``llm`` and the prompt in an ``LLMChain``, and restricts the agent to
        the given tools' names.

        Args:
            llm: The language model.
            tools: The tools the agent may use.
            callback_manager: The callback manager for the LLM chain.
                Default is None.
            output_parser: The output parser. When None, the default from
                ``_get_default_output_parser`` is used.
            system_message_prefix: The system message prefix.
                Default is SYSTEM_MESSAGE_PREFIX.
            system_message_suffix: The system message suffix.
                Default is SYSTEM_MESSAGE_SUFFIX.
            human_message: The human message. Default is HUMAN_MESSAGE.
            format_instructions: The format instructions.
                Default is FORMAT_INSTRUCTIONS.
            input_variables: The input variables. Default is None.
            kwargs: Additional keyword arguments passed to the constructor.

        Returns:
            An agent.
        """
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            system_message_prefix=system_message_prefix,
            system_message_suffix=system_message_suffix,
            human_message=human_message,
            format_instructions=format_instructions,
            input_variables=input_variables,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        # Explicit None check: only fall back when no parser was supplied.
        _output_parser = (
            output_parser
            if output_parser is not None
            else cls._get_default_output_parser()
        )
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )

    @property
    def _agent_type(self) -> str:
        """Not supported by this agent.

        Raises:
            ValueError: Always.
        """
        raise ValueError(f"{type(self).__name__} does not define an agent type.")
| |
|