# Gradio app entry point (Hugging Face Spaces).
import os
from pathlib import Path

import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings
from sambanova import SambaNova

from chatbot import (
    load_config,
    build_rag_corpus,
    retrieve_relevant_chunks,
    build_prompt,
    ask_model,
    format_answer,
)
| CONFIG_PATH = Path(__file__).parent / "config.yaml" | |
| RESOURCE_STATE = {} | |
| def init_resources(): | |
| if RESOURCE_STATE: | |
| return RESOURCE_STATE | |
| # Try to load from environment variables first (for Spaces) | |
| llm_api_key = os.getenv("SAMBANOVA_API_KEY") | |
| website = os.getenv("WEBSITE") | |
| embedding_model_name = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2") | |
| system_prompt = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant.") | |
| # Fallback to config.yaml if env vars not set | |
| if not llm_api_key or not website: | |
| if CONFIG_PATH.exists(): | |
| config = load_config(CONFIG_PATH) | |
| llm_api_key = llm_api_key or config.get("sambanova_api_key") | |
| website = website or config.get("website") | |
| embedding_model_name = embedding_model_name or config.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2") | |
| system_prompt = system_prompt or config.get("system_prompt", "You are a helpful assistant.") | |
| else: | |
| raise ValueError("Please set SAMBANOVA_API_KEY and WEBSITE as secrets in your Hugging Face Space settings, or provide config.yaml for local development") | |
| if not llm_api_key or not website: | |
| raise ValueError("SAMBANOVA_API_KEY and WEBSITE are required. Set them as secrets in Hugging Face Space settings.") | |
| embed_model = HuggingFaceEmbeddings(model_name=embedding_model_name) | |
| corpus = build_rag_corpus({"embedding_model": embedding_model_name}, embed_model, website) | |
| client = SambaNova( | |
| api_key=llm_api_key, | |
| base_url="https://api.sambanova.ai/v1", | |
| timeout=30, | |
| ) | |
| RESOURCE_STATE.update( | |
| config={"embedding_model": embedding_model_name}, | |
| website=website, | |
| system_prompt=system_prompt, | |
| embed_model=embed_model, | |
| corpus=corpus, | |
| client=client, | |
| ) | |
| return RESOURCE_STATE | |
| def answer_question(question: str): | |
| resources = init_resources() | |
| selected = retrieve_relevant_chunks( | |
| resources["corpus"], | |
| question, | |
| resources["embed_model"], | |
| top_k=4, | |
| ) | |
| prompt = build_prompt(resources["system_prompt"], question, selected) | |
| raw_answer = ask_model(prompt, resources["client"]) | |
| response = format_answer(raw_answer, selected) | |
| citations = "\n\n".join( | |
| [f"Chunk {i+1}: {chunk.text[:300]}..." for i, chunk in enumerate(selected)] | |
| ) | |
| return response, citations | |
| def main(): | |
| resources = init_resources() | |
| with gr.Blocks(title="RAG Chatbot") as demo: | |
| gr.Markdown("# 🤖 RAG-Powered Chatbot") | |
| gr.Markdown(f"**Website:** {resources['website']} \n**Chunks:** {len(resources['corpus'])}") | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| question_input = gr.Textbox(label="Ask a question", placeholder="What services do you provide?", lines=2) | |
| submit_button = gr.Button("Ask") | |
| answer_output = gr.Textbox(label="Answer", lines=12, interactive=False) | |
| with gr.Column(scale=1): | |
| citations_output = gr.Textbox(label="Citations", lines=20, interactive=False) | |
| submit_button.click( | |
| answer_question, | |
| inputs=[question_input], | |
| outputs=[answer_output, citations_output], | |
| ) | |
| demo.launch() | |
| if __name__ == "__main__": | |
| main() | |