File size: 3,574 Bytes
dcac338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from langchain_core.prompts import PromptTemplate
from langchain_chroma import Chroma
from langchain_community.llms import Ollama
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface import HuggingFaceEmbeddings
import sys

def load_vectorstore():
    """Open the Chroma collection that backs the assistant.

    Returns:
        A ``Chroma`` instance for the ``rag_code_assistant`` collection,
        using the ``chroma_db`` persist directory (a relative path) and
        ``all-MiniLM-L6-v2`` HuggingFace embeddings.
    """
    return Chroma(
        collection_name="rag_code_assistant",
        persist_directory="chroma_db",
        embedding_function=HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
    )

def load_prompt():
    """Build the prompt template used to ask the LLM a question.

    The template has two placeholders: ``{context}`` for the retrieved
    documentation excerpts and ``{question}`` for the user's question.

    Returns:
        A ``PromptTemplate`` with ``context`` and ``question`` as its
        input variables.
    """
    # The template is written one output line per entry and joined with
    # newlines; empty entries produce the blank lines between sections.
    template_lines = [
        (
            "You are an expert Python coding assistant. "
            "Use the following documentation excerpts to answer the user's "
            "question accurately. "
        ),
        "",
        "IMPORTANT INSTRUCTIONS:",
        "- Prioritize standard Python concepts over C-extensions or advanced typing unless specifically asked.",
        "- If multiple types of answers are in the context, synthesize them into a complete answer.",
        "- If the answer is not in the context, say 'I don't know'.",
        "",
        "Context:",
        "{context}",
        "",
        "Question: {question}",
        "",
        "Answer:",
    ]

    return PromptTemplate(
        template="\n".join(template_lines),
        input_variables=["context", "question"],
    )

def load_llm():
    """Create the LLM client: Ollama's ``llama3`` model at temperature 0.1."""
    return Ollama(model="llama3", temperature=0.1)

def format_docs(docs):
    """Merge the text of several documents into one context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The ``page_content`` values in order, separated by a blank line.
        An empty iterable yields an empty string.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)

def load_retriever(vectorstore, *, k=4, fetch_k=20):
    """Create an MMR retriever on top of ``vectorstore``.

    MMR is used instead of plain similarity search: the top ``fetch_k``
    candidates are fetched and the ``k`` most diverse of them are returned.

    Args:
        vectorstore: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)``.
        k: Number of chunks handed to the LLM. Defaults to 4.
        fetch_k: Size of the candidate pool searched before the diverse
            ``k`` are picked. Defaults to 20.

    Returns:
        Whatever ``vectorstore.as_retriever`` returns for these settings.
    """
    return vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": k, "fetch_k": fetch_k},
    )

def load_rag_chain(retriever, prompt, llm):
    """Compose the retrieval-augmented generation chain.

    The chain input is the question. It is passed through unchanged as
    ``question`` and also sent to ``retriever``, whose documents are
    flattened by ``format_docs`` into ``context``. Both values go into
    ``prompt``, then ``llm``, and the result is parsed with
    ``StrOutputParser``.

    Args:
        retriever: Retriever that is piped into ``format_docs``.
        prompt: Prompt taking ``context`` and ``question``.
        llm: Language model that receives the filled-in prompt.

    Returns:
        The composed chain.
    """
    prompt_inputs = {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    return prompt_inputs | prompt | llm | StrOutputParser()

def load_answer(rag_chain, retriever, question):
    """Stream the answer to stdout and return it with its source documents.

    Args:
        rag_chain: Object whose ``stream(question)`` yields the answer as
            string pieces.
        retriever: Object whose ``invoke(question)`` returns the source
            documents.
        question: The user's question.

    Returns:
        A ``(full_answer, source_docs)`` tuple: the concatenated streamed
        text and the result of ``retriever.invoke(question)``.
    """
    print("=== ASSISTANT'S ANSWER ===")

    # Echo each piece as soon as it arrives (no newline, flushed) and keep
    # it so the complete text can be returned afterwards.
    pieces = []
    for piece in rag_chain.stream(question):
        print(piece, end="", flush=True)
        pieces.append(piece)
    full_answer = "".join(pieces)

    # Finish the streamed line and leave a blank line after it.
    print("\n")

    # The sources come from a separate, non-streamed retriever call.
    source_docs = retriever.invoke(question)
    return full_answer, source_docs

# Build the pipeline at import time so the module exposes ready-made
# `vectorstore`, `prompt`, `llm`, `retriever` and `rag_chain` objects.
# Order matters: the retriever needs the vector store, and the chain needs
# the retriever, the prompt and the LLM.
vectorstore = load_vectorstore()
prompt = load_prompt()
llm = load_llm()
retriever = load_retriever(vectorstore)
rag_chain = load_rag_chain(retriever, prompt, llm)

if __name__ == "__main__":
    # Command-line entry point: answer the question given as the first
    # argument, or a default question when none is given.
    #
    # `rag_chain` and `retriever` are already built by the module-level
    # statements above, so they are reused here rather than constructed a
    # second time (which would load the embeddings and vector store twice).
    question = sys.argv[1] if len(sys.argv) > 1 else "What is Python?"
    print(f"\nUser Question: {question}\nThinking...\n")

    # load_answer prints the answer while it streams; only the source
    # documents are needed afterwards.
    _, source_docs = load_answer(rag_chain, retriever, question)

    print("\n=== SOURCES USED ===")
    for position, doc in enumerate(source_docs, start=1):
        source_file = doc.metadata.get("source", "Unknown file")
        print(f"{position}. {source_file}")