| from copy import deepcopy |
| import hydra |
|
|
| import time |
|
|
| from typing import Dict, Optional, Any |
|
|
| from flows.base_flows import AtomicFlow |
|
|
| from flows.utils import logging |
| from flows.messages.flow_message import UpdateMessage_ChatMessage |
|
|
| from flows.prompt_template import JinjaPrompt |
|
|
| from flows.backends.llm_lite import LiteLLMBackend |
|
|
# Module-level logger obtained from the project's `flows.utils.logging` helper.
log = logging.get_logger(__name__)
|
|
|
|
class OpenAIChatAtomicFlow(AtomicFlow):
    """Atomic flow that keeps a chat history and queries a chat backend.

    The conversation is stored in ``flow_state["previous_messages"]`` as a list of
    ``{"role": ..., "content": ...}`` dicts. Role names come from the flow config
    keys ``system_name``, ``user_name`` and ``assistant_name``.

    On the first call the flow adds the following to the history:

    * the system prompt;
    * any demonstrations found in the input;
    * the first user message, rendered from
      ``init_human_message_prompt_template`` when one is configured and from
      ``human_message_prompt_template`` otherwise.

    Later calls only append a user message rendered from
    ``human_message_prompt_template``. Every answer returned by the backend is
    appended to the history with the assistant role.
    """

    REQUIRED_KEYS_CONFIG = ["backend"]

    SUPPORTS_CACHING: bool = True

    system_message_prompt_template: JinjaPrompt
    human_message_prompt_template: JinjaPrompt

    backend: LiteLLMBackend
    init_human_message_prompt_template: Optional[JinjaPrompt] = None

    def __init__(self,
                 system_message_prompt_template,
                 human_message_prompt_template,
                 init_human_message_prompt_template,
                 backend,
                 **kwargs):
        """Store the prompt templates and the backend.

        :param system_message_prompt_template: Template for the system message.
        :param human_message_prompt_template: Template for user messages.
        :param init_human_message_prompt_template: Optional template for the
            first user message. When ``None``, the regular human template is
            used for the first message as well.
        :param backend: Callable chat backend used in ``_call``.
        :param kwargs: Forwarded to ``AtomicFlow.__init__``. Must include a
            ``flow_config`` that has a ``name`` entry.
        """
        super().__init__(**kwargs)
        self.system_message_prompt_template = system_message_prompt_template
        self.human_message_prompt_template = human_message_prompt_template
        self.init_human_message_prompt_template = init_human_message_prompt_template
        self.backend = backend
        # The flow name must not collide with the standard chat role names.
        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"

    def set_up_flow_state(self):
        """Initialize the flow state via the parent class and start an empty chat history."""
        super().set_up_flow_state()
        self.flow_state["previous_messages"] = []

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiate the prompt templates described in ``config`` via Hydra.

        ``init_human_message_prompt_template`` is optional. When the key is
        missing or ``None``, the corresponding kwarg is ``None``.

        :return: Dict of constructor kwargs for the three prompt templates.
        """
        kwargs = {}

        kwargs["system_message_prompt_template"] = \
            hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")

        init_template_config = config.get('init_human_message_prompt_template')
        kwargs["init_human_message_prompt_template"] = (
            hydra.utils.instantiate(init_template_config, _convert_="partial")
            if init_template_config is not None
            else None
        )

        kwargs["human_message_prompt_template"] = \
            hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")

        return kwargs

    @classmethod
    def _set_up_backend(cls, config):
        """Instantiate the backend described by ``config['backend']`` via Hydra.

        :return: Dict with the single constructor kwarg ``backend``.
        """
        return {
            "backend": hydra.utils.instantiate(config['backend'], _convert_="partial"),
        }

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config: prompts and backend are instantiated, config is deep-copied."""
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_backend(flow_config))

        return cls(**kwargs)

    def _is_conversation_initialized(self):
        """Return True once at least one message has been added to the history."""
        return bool(self.flow_state["previous_messages"])

    def get_interface_description(self):
        """Return the input/output interface for the current conversation stage.

        Before the first message, the input interface is taken from
        ``input_interface_non_initialized``; afterwards from
        ``input_interface_initialized``.
        """
        if self._is_conversation_initialized():
            input_key = "input_interface_initialized"
        else:
            input_key = "input_interface_non_initialized"
        return {"input": self.flow_config[input_key],
                "output": self.flow_config["output_interface"]}

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Render ``prompt_template`` using only the variables it declares.

        :raises KeyError: If ``input_data`` lacks one of the template's input variables.
        """
        template_kwargs = {
            input_variable: input_data[input_variable]
            for input_variable in prompt_template.input_variables
        }
        return prompt_template.format(**template_kwargs)

    def _get_demonstration_query_message_content(self, sample_data: Dict):
        """Render the init human template for one demonstration sample."""
        input_variables = self.init_human_message_prompt_template.input_variables
        return self.init_human_message_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _add_demonstrations(self, input_data: Dict[str, Any]):
        """Append each demonstration in ``input_data["demonstrations"]`` to the history.

        Each demonstration's ``query`` is added as a user message, followed by
        its ``response`` as an assistant message. A missing key means no
        demonstrations.
        """
        for demonstration in input_data.get("demonstrations", []):
            self._state_update_add_chat_message(content=demonstration["query"],
                                                role=self.flow_config["user_name"])
            self._state_update_add_chat_message(content=demonstration["response"],
                                                role=self.flow_config["assistant_name"])

    def _state_update_add_chat_message(self,
                                       role: str,
                                       content: str) -> None:
        """Append a message to the history and log it as an ``UpdateMessage_ChatMessage``.

        :param role: Must be one of the configured system/user/assistant names.
        :param content: Message text.
        :raises ValueError: If ``role`` is not an accepted role name.
            ``ValueError`` subclasses ``Exception``, so existing handlers still catch it.
        """
        acceptable_roles = [
            self.flow_config["system_name"],
            self.flow_config["user_name"],
            self.flow_config["assistant_name"],
        ]
        if role not in acceptable_roles:
            raise ValueError(f"Invalid role: `{role}`.\n"
                             f"Role should be one of: "
                             f"`{acceptable_roles}`")

        self.flow_state["previous_messages"].append({"role": role, "content": content})

        chat_message = UpdateMessage_ChatMessage(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            role=role,
            content=content,
        )
        self._log_message(chat_message)

    def _get_previous_messages(self):
        """Return the chat history, optionally truncated to its first/last messages.

        ``first_k`` and ``last_k`` are read from ``flow_config["previous_messages"]``.
        A falsy value means that side is not restricted. When both are set and
        the two windows would overlap, the full history is returned once
        instead of duplicating messages.
        """
        all_messages = self.flow_state["previous_messages"]
        first_k = self.flow_config["previous_messages"]["first_k"]
        last_k = self.flow_config["previous_messages"]["last_k"]

        if not first_k and not last_k:
            return all_messages
        if first_k and last_k:
            # Overlapping head/tail windows would send some messages twice.
            if first_k + last_k >= len(all_messages):
                return all_messages
            return all_messages[:first_k] + all_messages[-last_k:]
        if first_k:
            return all_messages[:first_k]
        return all_messages[-last_k:]

    def _call(self):
        """Query the backend with the (possibly truncated) history, retrying on failure.

        At most ``flow_config['n_api_retries']`` attempts are made. Between failed
        attempts the flow sleeps ``flow_config['wait_time_between_retries']``
        seconds; it does not sleep after the final attempt.

        :return: List of the ``"content"`` entries of the backend's answers.
        :raises ValueError: If ``n_api_retries`` is smaller than 1.
        :raises Exception: The last backend error, once all attempts have failed.
        """
        messages = self._get_previous_messages()
        n_api_retries = self.flow_config['n_api_retries']
        if n_api_retries < 1:
            # Without this guard the loop never runs and `raise None` would
            # surface as an unrelated TypeError.
            raise ValueError(f"n_api_retries must be at least 1, got {n_api_retries}")

        last_error = None
        for attempt in range(1, n_api_retries + 1):
            try:
                backend_output = self.backend(messages=messages, mock_response=False)
                # Assumes each answer is a mapping with a "content" key.
                return [answer["content"] for answer in backend_output]
            except Exception as e:
                last_error = e
                if attempt < n_api_retries:
                    wait_time = self.flow_config['wait_time_between_retries']
                    log.error(
                        f"Error {attempt} in calling backend: {e}. "
                        f"Retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                else:
                    log.error(f"Error {attempt} in calling backend: {e}. No retries left.")

        raise last_error

    def _initialize_conversation(self, input_data: Dict[str, Any]):
        """Add the rendered system message, then any demonstrations, to the history."""
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)

        self._state_update_add_chat_message(content=system_message_content,
                                            role=self.flow_config["system_name"])

        self._add_demonstrations(input_data)

    def _process_input(self, input_data: Dict[str, Any]):
        """Add the user message for this turn to the history.

        On the first turn the conversation is initialized first. The init human
        template is used when it is set; otherwise the regular human template
        is used.
        """
        if self._is_conversation_initialized():
            prompt_template = self.human_message_prompt_template
        else:
            self._initialize_conversation(input_data)
            if self.init_human_message_prompt_template is not None:
                prompt_template = self.init_human_message_prompt_template
            else:
                prompt_template = self.human_message_prompt_template

        user_message_content = self._get_message(prompt_template, input_data)
        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)

    def run(self,
            input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one chat turn: add the user message, query the backend, record the answers.

        :return: ``{"api_output": ...}``. The value is the single answer when the
            backend returned exactly one answer, otherwise the list of answers.
        """
        self._process_input(input_data)

        answers = self._call()

        for answer in answers:
            self._state_update_add_chat_message(
                role=self.flow_config["assistant_name"],
                content=answer
            )

        api_output = answers[0] if len(answers) == 1 else answers
        return {"api_output": api_output}
|
|