# chatbot_graph.py
import os
from dotenv import load_dotenv
import gradio as gr
import logging
from typing import List
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# LLM client (Groq wrapper)
try:
    from langchain_groq import ChatGroq
except Exception:
    ChatGroq = None
    logger.warning("langchain_groq.ChatGroq not importable. Ensure langchain-groq is installed in requirements.")
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from chatbot_retriever import retrieve_node_from_rows
from memory_store import init_db, save_message, get_last_messages, build_gradio_history
# initialize DB early
init_db()
# Instantiate Groq LLM (will require GROQ_API_KEY in env)
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.1-8b-instant")
GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
GROQ_TEMP = float(os.getenv("GROQ_TEMP", "0.2"))
if ChatGroq:
    llm = ChatGroq(model=GROQ_MODEL, api_key=GROQ_API_KEY, temperature=GROQ_TEMP)
else:
    llm = None
def _extract_answer_from_response(response):
    """Best-effort extraction of the answer text from common LLM response shapes."""
    try:
        if hasattr(response, "content"):
            c = response.content
            if isinstance(c, str) and c.strip():
                return c.strip()
            if isinstance(c, (list, tuple)):
                parts = [str(x) for x in c if x is not None]
                if parts:
                    return "".join(parts).strip()
            if isinstance(c, dict):
                for key in ("answer", "text", "content", "output_text", "generated_text"):
                    v = c.get(key)
                    if v:
                        if isinstance(v, (list, tuple)):
                            return "".join([str(x) for x in v]).strip()
                        return str(v).strip()
        if isinstance(response, dict):
            for key in ("answer", "text", "content"):
                v = response.get(key)
                if v:
                    return str(v)
            choices = response.get("choices") or response.get("outputs")
            if isinstance(choices, (list, tuple)) and choices:
                first = choices[0]
                if isinstance(first, dict):
                    msg = first.get("message") or first.get("text") or first.get("content")
                    if msg:
                        if isinstance(msg, (list, tuple)):
                            return "".join([str(x) for x in msg])
                        return str(msg)
        if hasattr(response, "generations"):
            gens = getattr(response, "generations")
            if gens:
                for outer in gens:
                    for g in outer:
                        if hasattr(g, "text") and g.text:
                            return str(g.text)
                        if hasattr(g, "message") and getattr(g.message, "content", None):
                            return str(g.message.content)
        s = str(response)
        if s and s.strip():
            return s.strip()
    except Exception:
        logger.exception("Failed extracting answer")
    return None
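# Illustrative inputs this helper normalizes (hypothetical payloads, not exact
# Groq response objects):
#   _extract_answer_from_response(AIMessage(content="hi"))           -> "hi"
#   _extract_answer_from_response({"answer": "42"})                  -> "42"
#   _extract_answer_from_response({"choices": [{"message": "ok"}]})  -> "ok"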
SYSTEM_PROMPT = (
    "You are PrepGraph — an accurate, concise AI tutor specialized in academic and technical content.\n"
    "Rules:\n"
    "1) Always prioritize answering the CURRENT user question directly and clearly.\n"
    "2) Refer to provided CONTEXT (delimited below) if relevant. Cite which doc (filename) or say 'from provided context' when applicable.\n"
    "3) If the current query is unclear, use ONLY the immediate previous user question to infer intent — not older ones.\n"
    "4) Provide step-by-step explanations when appropriate, using short, structured points.\n"
    "5) Include ASCII diagrams or flowcharts if they help understanding (e.g., for protocols, layers, architectures, etc.).\n"
    "6) If the context is insufficient or ambiguous, clearly say 'I’m unsure' and specify what extra information is needed.\n"
    "7) Avoid repetition, speculation, and hallucination — answer precisely what is asked.\n\n"
    "CONTEXT:\n"
)
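# At request time the retrieved context is appended after "CONTEXT:\n", so the
# model sees a system message ending in something like (purely illustrative):
#   "CONTEXT:\n[os_notes.pdf] A deadlock requires mutual exclusion, hold-and-wait, ..."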
# ---- helper: call the LLM with a list of messages (SystemMessage + HumanMessage...) ----
def call_llm(messages: List):
    if not llm:
        raise RuntimeError("LLM client (ChatGroq) not configured or import failed. Set up langchain_groq and GROQ_API_KEY.")
    # many wrappers accept the langchain message objects; keep using llm.invoke
    response = llm.invoke(messages)
    return response
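# Optional hardening (a sketch, not wired into the app): retry transient LLM
# failures with exponential backoff. Assumes a failed invoke has no side
# effects and is safe to repeat.
def call_llm_with_retry(messages: List, attempts: int = 3, base_delay: float = 1.0):
    import time  # local import keeps this optional helper self-contained
    for attempt in range(attempts):
        try:
            return call_llm(messages)
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))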
# ---- Gradio UI functions ----
def load_history(user_id: str):
    uid = (user_id or os.getenv("DEFAULT_USER", "vinayak")).strip() or "vinayak"
    try:
        hist = build_gradio_history(uid)
        logger.info("Loaded %d messages for user %s", len(hist), uid)
        return hist
    except Exception:
        logger.exception("Failed to load history for %s", uid)
        return []
def chat_interface(user_input: str, chat_state: List[dict], user_id: str):
    """
    Receives user_input (string), chat_state (list of {'role':..., 'content':...}),
    user_id (string). Returns: (clear_input_str, new_chat_state)
    """
    uid = (user_id or os.getenv("DEFAULT_USER", "vinayak")).strip() or "vinayak"
    history = chat_state or []
    # Save user's message immediately
    try:
        save_message(uid, "user", user_input)
    except Exception:
        logger.exception("Failed to persist user message")
    # Build rows to pass to retriever: get last messages from DB (ensures persistence)
    rows = get_last_messages(uid, limit=200)  # chronological order
    # Retrieve context using hybrid retriever (uses last 3 user messages internally)
    try:
        retrieved = retrieve_node_from_rows(rows)
        context = retrieved.get("context")
    except Exception:
        logger.exception("Retriever failed")
        context = None
    # Build prompt: SystemMessage + last 3 user messages (HumanMessage)
    prompt_msgs = []
    system_content = SYSTEM_PROMPT + (context or "No context found.")
    prompt_msgs.append(SystemMessage(content=system_content))
    # collect last 3 user messages (from rows)
    last_users = [r[1] for r in rows if r[0] == "user"][-3:]
    if not last_users:
        # fallback to current input if DB empty
        last_users = [user_input]
    # append each of the last user messages as HumanMessage (preserves order)
    for u in last_users:
        prompt_msgs.append(HumanMessage(content=u))
    # send to LLM
    try:
        raw = call_llm(prompt_msgs)
        answer = _extract_answer_from_response(raw) or ""
    except Exception as e:
        logger.exception("LLM call failed")
        answer = f"Sorry — I couldn't process that right now ({e})."
    # persist assistant reply
    try:
        save_message(uid, "assistant", answer)
    except Exception:
        logger.exception("Failed to persist assistant message")
    # update gradio chat state: append current user and assistant
    history = history or load_history(uid)  # in case front-end was empty, rehydrate
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": answer})
    # return: clear the input box (""), updated history for gr.Chatbot(type="messages")
    return "", history
# ---- Minimal / attractive Gradio UI ----
with gr.Blocks(css=".gradio-container {max-width:900px; margin:0 auto;}") as demo:
    gr.Markdown("# 🤖 PrepGraph — RAG Tutor")
    with gr.Row():
        user_id_input = gr.Textbox(label="User ID (will be used to persist your memory)", value=os.getenv("DEFAULT_USER", "vinayak"))
    chatbot = gr.Chatbot(label="Conversation", type="messages")
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask anything about your course material...", show_label=False)
        send = gr.Button("Send")
    with gr.Row():
        clear_ui = gr.Button("Clear Chat")
    # Load history at page load (and when user_id changes)
    demo.load(load_history, [user_id_input], [chatbot])
    user_id_input.change(load_history, [user_id_input], [chatbot])
    # Bind send
    msg.submit(chat_interface, [msg, chatbot, user_id_input], [msg, chatbot])
    send.click(chat_interface, [msg, chatbot, user_id_input], [msg, chatbot])
    # just clears the UI, not the DB
    clear_ui.click(lambda: [], None, chatbot)
if __name__ == "__main__":
    demo.launch()
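    # On Hugging Face Spaces you can bind host/port explicitly instead (optional):
    # demo.launch(server_name="0.0.0.0", server_port=7860)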