| |
| |
| |
| |
|
|
| import os |
| from config import MODEL, INFO, HOST |
| from openai import AsyncOpenAI |
| import gradio as gr |
|
|
async def playground(
    message,
    history,
    num_ctx,
    max_tokens,
    temperature,
    repeat_penalty,
    top_k,
    top_p
):
    """Stream a chat completion for *message* given the prior *history*.

    Yields the accumulated response text after every received chunk so the
    Gradio ChatInterface can render the answer incrementally.

    Parameters mirror the sidebar sliders: ``num_ctx`` (context window),
    ``max_tokens``, ``temperature``, ``repeat_penalty``, ``top_k`` and
    ``top_p`` are forwarded to the OpenAI-compatible (Ollama) endpoint;
    the Ollama-only options travel in ``extra_body``.
    """
    # Guard: ignore empty / non-string messages instead of sending them upstream.
    if not isinstance(message, str) or not message.strip():
        yield []
        return

    # Rebuild the conversation, keeping only well-formed history entries
    # (dicts that carry both "role" and "content").
    messages = []
    for item in history:
        if isinstance(item, dict) and "role" in item and "content" in item:
            messages.append({
                "role": item["role"],
                "content": item["content"]
            })
    messages.append({"role": "user", "content": message})

    response = ""
    # BUG FIX: the original constructed a new AsyncOpenAI client per call and
    # never closed it, leaking the underlying HTTP connection pool.  The
    # ``async with`` block guarantees the client is closed even if the stream
    # raises mid-way.
    async with AsyncOpenAI(
        base_url=os.getenv("OLLAMA_API_BASE_URL"),
        api_key=os.getenv("OLLAMA_API_KEY")
    ) as client:
        stream = await client.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            stream=True,
            # Ollama-specific sampling knobs are not part of the OpenAI
            # schema, so they ride in extra_body.
            extra_body={
                "num_ctx": int(num_ctx),
                "repeat_penalty": float(repeat_penalty),
                "top_k": int(top_k)
            }
        )

        async for chunk in stream:
            # Keep-alive / finish chunks may carry no delta content; skip them.
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                response += chunk.choices[0].delta.content
                yield response
|
|
# Top-level UI: a sidebar of sampling-parameter sliders plus a streaming
# chat interface wired to `playground`.
with gr.Blocks(
    fill_height=True,
    fill_width=False
) as app:
    with gr.Sidebar():
        gr.HTML(INFO)
        gr.Markdown("---")
        gr.Markdown("## Model Parameters")

        def _slider(lo, hi, default, step, label, info, spacer=True):
            # Build one parameter slider; every slider except the first is
            # preceded by an empty Markdown spacer row, matching the
            # original component creation order exactly.
            if spacer:
                gr.Markdown("")
            return gr.Slider(
                minimum=lo,
                maximum=hi,
                value=default,
                step=step,
                label=label,
                info=info
            )

        num_ctx = _slider(
            512, 8192, 512, 128,
            "Context Length",
            "Maximum context window size (memory)",
            spacer=False
        )
        max_tokens = _slider(
            512, 8192, 512, 128,
            "Max Tokens",
            "Maximum number of tokens to generate"
        )
        temperature = _slider(
            0.1, 1.0, 0.1, 0.1,
            "Temperature",
            "Controls randomness in generation"
        )
        repeat_penalty = _slider(
            0.1, 2.0, 1.05, 0.1,
            "Repetition Penalty",
            "Penalty for repeating tokens"
        )
        top_k = _slider(
            0, 100, 50, 1,
            "Top K",
            "Number of top tokens to consider"
        )
        top_p = _slider(
            0.0, 1.0, 0.1, 0.05,
            "Top P",
            "Cumulative probability threshold"
        )

    # The chat surface: slider values are appended to each call as
    # additional inputs, in the order `playground` expects them.
    gr.ChatInterface(
        fn=playground,
        additional_inputs=[
            num_ctx,
            max_tokens,
            temperature,
            repeat_penalty,
            top_k,
            top_p
        ],
        type="messages",
        examples=[
            ["Please introduce yourself."],
            ["What caused World War II?"],
            ["Give me a short introduction to large language model."],
            ["Explain about quantum computers."]
        ],
        cache_examples=False,
        show_api=False
    )
|
|
# Serve the app on the configured host (port defaults to Gradio's 7860);
# pwa=True lets browsers install the UI as a progressive web app.
app.launch(
    server_name=HOST,
    pwa=True
)