| import gradio as gr |
| import os |
| from langchain.chains import ConversationalRetrievalChain |
| from langchain.memory import ConversationBufferMemory |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
| from langchain.prompts import PromptTemplate |
| from langchain_community.vectorstores import Chroma |
|
|
def create_qa_chain():
    """Assemble the conversational QA chain over the persisted vectorstore.

    Loads the Chroma store from ``./vectorstore``, attaches an MMR retriever,
    buffer memory, and a documentation-formatting prompt, and returns a
    ready-to-use ``ConversationalRetrievalChain``.

    Returns:
        ConversationalRetrievalChain: chain that answers questions with
        source documents attached.
    """
    # Re-open the on-disk Chroma index with the same embedding model that
    # was (presumably) used to build it — TODO confirm against the indexer.
    store = Chroma(
        persist_directory="./vectorstore",
        embedding_function=OpenAIEmbeddings(),
    )

    # MMR retrieval: pull 20 candidates, return 6; lambda_mult=0.3 leans
    # toward diversity over raw similarity.
    doc_retriever = store.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.3},
    )

    # Chat history buffer; output_key='answer' tells the memory which of the
    # chain's multiple outputs (answer vs. source_documents) to record.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key='answer',
    )

    # Prompt forcing a fixed, markdown-formatted API-documentation layout.
    answer_prompt = PromptTemplate.from_template("""You are an expert technical writer specializing in API documentation.
When describing API endpoints, structure your response in this exact format:

1. Start with the HTTP method and base URI structure
2. List all key parameters with:
- Parameter name in bold (**parameter**)
- Type and requirement status
- Clear description
- Example values where applicable
3. Show complete example requests with:
- Basic example
- Full example with all parameters
- Headers included
4. Include any relevant response information

Use markdown formatting for:
- Code blocks with syntax highlighting
- Bold text for important terms
- Clear section separation

Context: {context}

Question: {question}

Technical answer (following the exact structure above):""")

    # Low temperature keeps the documentation answers deterministic.
    answer_llm = ChatOpenAI(
        temperature=0.1,
        model_name="gpt-4-turbo-preview",
    )

    return ConversationalRetrievalChain.from_llm(
        llm=answer_llm,
        retriever=doc_retriever,
        memory=chat_memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": answer_prompt},
        verbose=False,
    )
|
|
def chat(message, history):
    """Answer one chat message and append up to three cited sources.

    Args:
        message: The user's question.
        history: Gradio-supplied chat history; unused here because the
            chain's own ConversationBufferMemory tracks the conversation.

    Returns:
        The chain's answer followed by a "Sources:" section listing up to
        three source documents, deduplicated by (component, title).
    """
    # Lazily build the chain once and cache it on the function object so the
    # expensive vectorstore/LLM setup only happens on the first message.
    if not hasattr(chat, 'qa_chain'):
        chat.qa_chain = create_qa_chain()

    # .invoke() is the current Chain entry point; calling the chain directly
    # (chain({...})) is deprecated in modern LangChain.
    result = chat.qa_chain.invoke({"question": message})

    sources = "\n\nSources:\n"
    seen_components = set()
    shown_sources = 0

    for doc in result["source_documents"]:
        component = doc.metadata.get('component', '')
        title = doc.metadata.get('title', '')
        combo = (component, title)

        # Skip documents from a (component, title) pair we already cited.
        if combo in seen_components:
            continue

        seen_components.add(combo)
        shown_sources += 1
        sources += f"\nSource {shown_sources}:\n"
        sources += f"Title: {title}\n"
        sources += f"Component: {component}\n"
        # Truncate long excerpts to keep the chat response readable.
        sources += f"Content: {doc.page_content[:300]}...\n"

        # Cap the citation list at three sources; stop scanning early.
        if shown_sources >= 3:
            break

    return result["answer"] + sources
|
|
# Gradio chat UI wired to the `chat` handler; the examples seed clickable
# starter prompts in the interface.
demo = gr.ChatInterface(
    chat,
    title="Apple Music API Documentation Assistant",
    description="Ask questions about the Apple Music API documentation.",
    examples=[
        "How to search for songs on Apple Music API?",
        "What are the required parameters for searching songs?",
        "Show me an example request with all parameters"
    ]
)


# Start the local Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|