| | from langgraph.graph import StateGraph, END |
| | from langgraph.checkpoint.memory import MemorySaver |
| | from langchain_core.prompts import ChatPromptTemplate |
| | from core.state import GraphState |
| | from config.config import Config |
| |
|
class ProjectRAGGraph:
    """Retrieval-augmented question answering over a vector store.

    The compiled LangGraph workflow runs two nodes in sequence:
    ``retrieve`` (vector-store search for the question) followed by
    ``generate`` (LLM answer grounded in the retrieved documents).
    The graph is compiled with a ``MemorySaver`` checkpointer, so every
    run is keyed by the ``thread_id`` passed to :meth:`query`.
    """

    # The answer prompt never changes, so it is compiled once per instance
    # (in __init__) instead of on every ``generate`` call.
    _ANSWER_PROMPT = (
        "You are a professional Project Analyst.\n"
        "Context:\n"
        "{context}\n"
        "\n"
        "Question:\n"
        "{question}\n"
        "\n"
        "Answer strictly using the context. Cite sources.\n"
    )

    def __init__(self, llm, vector_store):
        """Store collaborators and compile the retrieve -> generate workflow.

        Args:
            llm: Runnable piped after the prompt to produce the answer.
            vector_store: Store exposing ``as_retriever(search_type=...,
                search_kwargs=...)``.
        """
        self.llm = llm
        self.vector_store = vector_store
        self.memory = MemorySaver()
        self._prompt = ChatPromptTemplate.from_template(self._ANSWER_PROMPT)
        self.workflow = self._build_graph()

    @staticmethod
    def _search_kwargs():
        """Return retriever ``search_kwargs`` appropriate for ``Config.SEARCH_TYPE``.

        ``lambda_mult`` only applies to maximal-marginal-relevance ("mmr")
        search. For other search types it is left out, because the retriever
        forwards ``search_kwargs`` to the store's search call, and some stores
        may pass unknown kwargs on to their backend and reject them.
        """
        kwargs = {"k": Config.TOP_K}
        if Config.SEARCH_TYPE == "mmr":
            kwargs["lambda_mult"] = Config.LAMBDA_MULT
        return kwargs

    def retrieve(self, state: GraphState):
        """Graph node: fetch documents relevant to ``state["question"]``.

        Returns:
            Partial state update ``{"context": docs}``.

        Raises:
            RuntimeError: if building the retriever or running the search
                fails; the original exception is chained as ``__cause__``.
        """
        try:
            retriever = self.vector_store.as_retriever(
                search_type=Config.SEARCH_TYPE,
                search_kwargs=self._search_kwargs(),
            )
            docs = retriever.invoke(state["question"])
        except Exception as e:
            raise RuntimeError(f"Retrieval failed: {e}") from e
        return {"context": docs}

    @staticmethod
    def _format_context(docs):
        """Join document texts into one context string, one labelled chunk per doc.

        The prompt asks the model to cite sources, so each chunk is prefixed
        with its ``metadata["source"]`` when present (typically set by
        LangChain document loaders — confirm for this project's ingestion).
        Chunks without a source get a positional label so they can still be
        referenced.
        """
        chunks = []
        for position, doc in enumerate(docs, start=1):
            metadata = getattr(doc, "metadata", None) or {}
            source = metadata.get("source", f"document {position}")
            chunks.append(f"[Source: {source}]\n{doc.page_content}")
        return "\n\n".join(chunks)

    def generate(self, state: GraphState):
        """Graph node: answer ``state["question"]`` from ``state["context"]``.

        Returns:
            Partial state update ``{"answer": text}``.

        Raises:
            RuntimeError: if formatting, the LLM call, or reading its output
                fails; the original exception is chained as ``__cause__``.
        """
        try:
            chain = self._prompt | self.llm
            response = chain.invoke({
                "context": self._format_context(state["context"]),
                "question": state["question"],
            })
            # Chat models return a message with ``.content``; plain text LLMs
            # return a string directly.
            answer = getattr(response, "content", response)
        except Exception as e:
            raise RuntimeError(f"Answer generation failed: {e}") from e
        return {"answer": answer}

    def _build_graph(self):
        """Wire retrieve -> generate -> END and compile with the checkpointer."""
        graph = StateGraph(GraphState)

        graph.add_node("retrieve", self.retrieve)
        graph.add_node("generate", self.generate)

        graph.set_entry_point("retrieve")
        graph.add_edge("retrieve", "generate")
        graph.add_edge("generate", END)

        return graph.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the workflow for ``question`` and return the generated answer.

        Args:
            question: User question; stored as ``state["question"]``.
            thread_id: Checkpointer thread key passed via
                ``config["configurable"]["thread_id"]``.

        Raises:
            RuntimeError: wrapping any failure from the graph run (including
                the RuntimeErrors raised by the nodes), chained via ``from``.
        """
        config = {"configurable": {"thread_id": thread_id}}
        try:
            result = self.workflow.invoke({"question": question}, config=config)
            return result["answer"]
        except Exception as e:
            raise RuntimeError(f"Graph execution failed: {e}") from e
| |
|