| | from copy import deepcopy |
| | import hydra |
| |
|
| | import time |
| |
|
| | from typing import Dict, Optional, Any |
| |
|
| | from flows.base_flows import AtomicFlow |
| | from flows.datasets import GenericDemonstrationsDataset |
| |
|
| | from flows.utils import logging |
| | from flows.messages.flow_message import UpdateMessage_ChatMessage |
| |
|
| | from flows.prompt_template import JinjaPrompt |
| |
|
| | from backends.llm_lite import LiteLLMBackend |
| |
|
| | log = logging.get_logger(__name__) |
| |
|
| |
|
class OpenAIChatAtomicFlow(AtomicFlow):
    # Atomic flow that maintains a chat history in its flow state, renders
    # Jinja prompt templates into chat messages, and queries an LLM chat
    # backend (with retries) via `run()`.
    REQUIRED_KEYS_CONFIG = ["backend"]

    # Responses are deterministic enough to be cached by the framework.
    SUPPORTS_CACHING: bool = True

    # Template for the system message that seeds every conversation.
    system_message_prompt_template: JinjaPrompt
    # Template for user messages once the conversation is initialized.
    human_message_prompt_template: JinjaPrompt

    # LLM chat backend; called with the selected message history in `_call`.
    backend: LiteLLMBackend
    # Optional template for the *first* user message (falls back to
    # `human_message_prompt_template` when None — see `_process_input`).
    init_human_message_prompt_template: Optional[JinjaPrompt] = None
    # Optional few-shot demonstrations dataset (query/response pairs).
    demonstrations: Optional[GenericDemonstrationsDataset] = None
    # Cap on how many demonstrations are added; None means "all".
    demonstrations_k: Optional[int] = None
    # Template for demonstration responses; it is rendered via
    # `.input_variables` / `.format`, so it is a prompt template, not a str.
    demonstrations_response_prompt_template: Optional[JinjaPrompt] = None
|
| | def __init__(self, |
| | system_message_prompt_template, |
| | human_message_prompt_template, |
| | init_human_message_prompt_template, |
| | backend, |
| | demonstrations_response_prompt_template=None, |
| | demonstrations=None, |
| | **kwargs): |
| | super().__init__(**kwargs) |
| | self.system_message_prompt_template = system_message_prompt_template |
| | self.human_message_prompt_template = human_message_prompt_template |
| | self.init_human_message_prompt_template = init_human_message_prompt_template |
| | self.demonstrations_response_prompt_template = demonstrations_response_prompt_template |
| | self.demonstrations = demonstrations |
| | self.demonstrations_k = self.flow_config.get("demonstrations_k", None) |
| | self.backend = backend |
| | assert self.flow_config["name"] not in [ |
| | "system", |
| | "user", |
| | "assistant", |
| | ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'" |
| |
|
    def set_up_flow_state(self):
        """Extend the base flow state with an empty chat history."""
        super().set_up_flow_state()
        # All exchanged chat messages ({"role", "content"} dicts) live here.
        self.flow_state["previous_messages"] = []
| |
|
| | @classmethod |
| | def _set_up_prompts(cls, config): |
| | kwargs = {} |
| | |
| | kwargs["system_message_prompt_template"] = \ |
| | hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial") |
| | kwargs["init_human_message_prompt_template"] = \ |
| | hydra.utils.instantiate(config['init_human_message_prompt_template'], _convert_="partial") |
| | kwargs["human_message_prompt_template"] = \ |
| | hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial") |
| | |
| | if "demonstrations_response_prompt_template" in config: |
| | kwargs["demonstrations_response_prompt_template"] = \ |
| | hydra.utils.instantiate(config['demonstrations_response_prompt_template'], _convert_="partial") |
| | kwargs["demonstrations"] = GenericDemonstrationsDataset(**config['demonstrations']) |
| |
|
| | return kwargs |
| | |
| | @classmethod |
| | def _set_up_backend(cls, config): |
| | kwargs = {} |
| |
|
| | kwargs["backend"] = \ |
| | hydra.utils.instantiate(config['backend'], _convert_="partial") |
| | |
| | return kwargs |
| |
|
| | @classmethod |
| | def instantiate_from_config(cls, config): |
| | flow_config = deepcopy(config) |
| |
|
| | kwargs = {"flow_config": flow_config} |
| |
|
| | |
| | kwargs.update(cls._set_up_prompts(flow_config)) |
| | kwargs.update(cls._set_up_backend(flow_config)) |
| |
|
| | |
| | return cls(**kwargs) |
| |
|
| | def _is_conversation_initialized(self): |
| | if len(self.flow_state["previous_messages"]) > 0: |
| | return True |
| |
|
| | return False |
| |
|
| | def get_interface_description(self): |
| | if self._is_conversation_initialized(): |
| |
|
| | return {"input": self.flow_config["input_interface_initialized"], |
| | "output": self.flow_config["output_interface"]} |
| | else: |
| | return {"input": self.flow_config["input_interface_non_initialized"], |
| | "output": self.flow_config["output_interface"]} |
| |
|
| | @staticmethod |
| | def _get_message(prompt_template, input_data: Dict[str, Any]): |
| | template_kwargs = {} |
| | for input_variable in prompt_template.input_variables: |
| | template_kwargs[input_variable] = input_data[input_variable] |
| | msg_content = prompt_template.format(**template_kwargs) |
| | return msg_content |
| |
|
| | def _get_demonstration_query_message_content(self, sample_data: Dict): |
| | input_variables = self.init_human_message_prompt_template.input_variables |
| | return self.init_human_message_prompt_template.format(**{k: sample_data[k] for k in input_variables}) |
| |
|
| | def _get_demonstration_response_message_content(self, sample_data: Dict): |
| | input_variables = self.demonstrations_response_prompt_template.input_variables |
| | return self.demonstrations_response_prompt_template.format(**{k: sample_data[k] for k in input_variables}) |
| |
|
| | def _add_demonstrations(self): |
| | if self.demonstrations is not None: |
| | demonstrations = self.demonstrations |
| |
|
| | c = 0 |
| | for example in demonstrations: |
| | if self.demonstrations_k is not None and c >= self.demonstrations_k: |
| | break |
| | c += 1 |
| | query = self._get_demonstration_query_message_content(example) |
| | response = self._get_demonstration_response_message_content(example) |
| |
|
| | self._state_update_add_chat_message(content=query, |
| | role=self.flow_config["user_name"]) |
| |
|
| | self._state_update_add_chat_message(content=response, |
| | role=self.flow_config["assistant_name"]) |
| |
|
| | def _state_update_add_chat_message(self, |
| | role: str, |
| | content: str) -> None: |
| |
|
| |
|
| | acceptable_roles = [self.flow_config["system_name"],self.flow_config["user_name"],self.flow_config["assistant_name"]] |
| | if role in acceptable_roles: |
| | self.flow_state["previous_messages"].append({"role": role , "content": content}) |
| | |
| | else: |
| | raise Exception(f"Invalid role: `{role}`.\n" |
| | f"Role should be one of: " |
| | f"`{acceptable_roles}`, ") |
| | |
| |
|
| | |
| | chat_message = UpdateMessage_ChatMessage( |
| | created_by=self.flow_config["name"], |
| | updated_flow=self.flow_config["name"], |
| | role=role, |
| | content=content, |
| | ) |
| | self._log_message(chat_message) |
| |
|
| | def _get_previous_messages(self): |
| | all_messages = self.flow_state["previous_messages"] |
| | first_k = self.flow_config["previous_messages"]["first_k"] |
| | last_k = self.flow_config["previous_messages"]["last_k"] |
| |
|
| | if not first_k and not last_k: |
| | return all_messages |
| | elif first_k and last_k: |
| | return all_messages[:first_k] + all_messages[-last_k:] |
| | elif first_k: |
| | return all_messages[:first_k] |
| | return all_messages[-last_k:] |
| |
|
| | def _call(self): |
| | |
| | messages = self._get_previous_messages() |
| | _success = False |
| | attempts = 1 |
| | error = None |
| | response = None |
| | while attempts <= self.flow_config['n_api_retries']: |
| | try: |
| | response = self.backend(messages=messages,mock_response=False) |
| | response = [ answer["content"] for answer in response] |
| | _success = True |
| | break |
| | except Exception as e: |
| | log.error( |
| | f"Error {attempts} in calling backend: {e}. " |
| | f"Retrying in {self.flow_config['wait_time_between_retries']} seconds..." |
| | ) |
| | |
| | |
| | |
| | |
| | attempts += 1 |
| | time.sleep(self.flow_config['wait_time_between_retries']) |
| | error = e |
| | |
| | if not _success: |
| | raise error |
| |
|
| | return response |
| |
|
| | def _initialize_conversation(self, input_data: Dict[str, Any]): |
| | |
| | system_message_content = self._get_message(self.system_message_prompt_template, input_data) |
| |
|
| | self._state_update_add_chat_message(content=system_message_content, |
| | role=self.flow_config["system_name"]) |
| |
|
| | |
| | self._add_demonstrations() |
| |
|
| | def _process_input(self, input_data: Dict[str, Any]): |
| | if self._is_conversation_initialized(): |
| | |
| | user_message_content = self._get_message(self.human_message_prompt_template, input_data) |
| |
|
| | else: |
| | |
| | self._initialize_conversation(input_data) |
| | if getattr(self, "init_human_message_prompt_template", None) is not None: |
| | |
| | user_message_content = self._get_message(self.init_human_message_prompt_template, input_data) |
| | else: |
| | user_message_content = self._get_message(self.human_message_prompt_template, input_data) |
| |
|
| | self._state_update_add_chat_message(role=self.flow_config["user_name"], |
| | content=user_message_content) |
| |
|
| | def run(self, |
| | input_data: Dict[str, Any]) -> Dict[str, Any]: |
| | |
| | self._process_input(input_data) |
| |
|
| | |
| | response = self._call() |
| | |
| | |
| | for answer in response: |
| | self._state_update_add_chat_message( |
| | role=self.flow_config["assistant_name"], |
| | content=answer |
| | ) |
| |
|
| | return {"api_output": response} |
| |
|