import gradio as gr
from huggingface_hub import InferenceClient

# Default generation settings surfaced as UI controls.
SYSTEM_MESSAGE_DEFAULT = "You are a friendly Chatbot."
MAX_TOKENS_DEFAULT = 512
TEMPERATURE_DEFAULT = 0.7
TOP_P_DEFAULT = 0.95

# Client for the hosted HuggingFaceH4/zephyr-7b-beta model on the Hugging Face Inference API.
inference_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


def respond(
    user_message: str,
    conversation_history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
| """ |
| Respond to a user message given the conversation history and other parameters. |
| |
| Args: |
| user_message (str): The user's message. |
| conversation_history (list[tuple[str, str]]): The conversation history. |
| system_message (str): The system message to display at the top of the chat interface. |
| max_tokens (int): The maximum number of tokens to generate in the response. |
| temperature (float): The temperature to use when generating text. |
| top_p (float): The top-p value to use when generating text. |
| |
| Yields: |
| list[tuple[str, str]]: Updated conversation history with the new assistant response. |
| """ |
| messages = [{"role": "system", "content": system_message}] |
| |
| for user_input, assistant_response in conversation_history: |
| if user_input: |
| messages.append({"role": "user", "content": user_input}) |
| if assistant_response: |
| messages.append({"role": "assistant", "content": assistant_response}) |
|
|
| |
| messages.append({"role": "user", "content": user_message}) |
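
    # Stream the completion, yielding the partial response so the UI updates live.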
    response = ""

    for message in inference_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = message.choices[0].delta.content
        # Skip chunks with no content (the final chunk's delta can be None).
        if token:
            response += token
        yield response
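

# Build the chat UI; the controls under `additional_inputs` are passed to
# `respond` in order, after the message and history arguments.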
chatbot_interface = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(height=600),
    additional_inputs=[
        gr.Textbox(
            value=SYSTEM_MESSAGE_DEFAULT,
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=MAX_TOKENS_DEFAULT,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=TEMPERATURE_DEFAULT,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TOP_P_DEFAULT,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    chatbot_interface.launch()