# chatbot_graph.py
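"""Gradio front-end for PrepGraph, a RAG tutor.

Wires together a Groq-hosted chat LLM, a hybrid retriever
(chatbot_retriever) and a persistent per-user message store
(memory_store) behind a minimal Gradio Blocks UI.
"""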
import os
from dotenv import load_dotenv
import gradio as gr
import logging
from typing import List

load_dotenv()
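# Environment variables used below (a local .env file is loaded if present):
#   GROQ_API_KEY  - required for the Groq LLM client
#   GROQ_MODEL    - optional, defaults to "llama-3.1-8b-instant"
#   GROQ_TEMP     - optional sampling temperature, defaults to 0.2
#   DEFAULT_USER  - optional fallback user id, defaults to "vinayak"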
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# LLM client (Groq wrapper)
try:
    from langchain_groq import ChatGroq
except Exception:
    ChatGroq = None
    logger.warning("langchain_groq.ChatGroq not importable. Ensure langchain-groq is installed in requirements.")

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

from chatbot_retriever import retrieve_node_from_rows
from memory_store import init_db, save_message, get_last_messages, build_gradio_history

# initialize DB early
init_db()

# Instantiate Groq LLM (will require GROQ_API_KEY in env)
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.1-8b-instant")
GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
GROQ_TEMP = float(os.getenv("GROQ_TEMP", "0.2"))

if ChatGroq:
    llm = ChatGroq(model=GROQ_MODEL, api_key=GROQ_API_KEY, temperature=GROQ_TEMP)
else:
    llm = None


def _extract_answer_from_response(response):
    """Best-effort extraction of answer text from an LLM response.

    Handles LangChain message objects (str/list/dict ``content``), plain
    dict responses (``answer``/``text``/``choices``), and objects exposing
    ``.generations``; falls back to ``str(response)``. Returns None only
    if nothing usable is found.
    """
    try:
        if hasattr(response, "content"):
            c = response.content
            if isinstance(c, str) and c.strip():
                return c.strip()
            if isinstance(c, (list, tuple)):
                parts = [str(x) for x in c if x is not None]
                if parts:
                    return "".join(parts).strip()
            if isinstance(c, dict):
                for key in ("answer", "text", "content", "output_text", "generated_text"):
                    v = c.get(key)
                    if v:
                        if isinstance(v, (list, tuple)):
                            return "".join([str(x) for x in v]).strip()
                        return str(v).strip()
        if isinstance(response, dict):
            for key in ("answer", "text", "content"):
                v = response.get(key)
                if v:
                    return str(v)
            choices = response.get("choices") or response.get("outputs")
            if isinstance(choices, (list, tuple)) and choices:
                first = choices[0]
                if isinstance(first, dict):
                    msg = first.get("message") or first.get("text") or first.get("content")
                    if msg:
                        if isinstance(msg, (list, tuple)):
                            return "".join([str(x) for x in msg])
                        return str(msg)
        if hasattr(response, "generations"):
            gens = getattr(response, "generations")
            if gens:
                for outer in gens:
                    for g in outer:
                        if hasattr(g, "text") and g.text:
                            return str(g.text)
                        if hasattr(g, "message") and getattr(g.message, "content", None):
                            return str(g.message.content)
        s = str(response)
        if s and s.strip():
            return s.strip()
    except Exception:
        logger.exception("Failed extracting answer")
    return None


SYSTEM_PROMPT = (
    "You are PrepGraph — an accurate, concise AI tutor specialized in academic and technical content.\n"
    "Rules:\n"
    "1) Always prioritize answering the CURRENT user question directly and clearly.\n"
    "2) Refer to provided CONTEXT (delimited below) if relevant. Cite which doc (filename) or say 'from provided context' when applicable.\n"
    "3) If the current query is unclear, use ONLY the immediate previous user question to infer intent — not older ones.\n"
    "4) Provide step-by-step explanations when appropriate, using short, structured points.\n"
    "5) Include ASCII diagrams or flowcharts if they help understanding (e.g., for protocols, layers, architectures, etc.).\n"
    "6) If the context is insufficient or ambiguous, clearly say 'I’m unsure' and specify what extra information is needed.\n"
    "7) Avoid repetition, speculation, and hallucination — answer precisely what is asked.\n\n"
    "CONTEXT:\n"
)
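
# The retrieved context string is concatenated onto SYSTEM_PROMPT per
# request (see chat_interface below).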

# ---- helper: call the LLM with a list of messages (SystemMessage + HumanMessage...) ----
def call_llm(messages: List):
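    """Invoke the configured chat model with a list of LangChain messages.

    Raises RuntimeError if the ChatGroq client could not be set up.
    """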
    if not llm:
        raise RuntimeError("LLM client (ChatGroq) not configured or import failed. Set up langchain_groq and GROQ_API_KEY.")
    # LangChain chat wrappers accept message objects directly via .invoke()
    response = llm.invoke(messages)
    return response


# ---- Gradio UI functions ----
def load_history(user_id: str):
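    """Return the persisted chat history for user_id in Gradio 'messages' format."""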
    uid = (user_id or os.getenv("DEFAULT_USER", "vinayak")).strip() or "vinayak"
    try:
        hist = build_gradio_history(uid)
        logger.info("Loaded %d messages for user %s", len(hist), uid)
        return hist
    except Exception:
        logger.exception("Failed to load history for %s", uid)
        return []


def chat_interface(user_input: str, chat_state: List[dict], user_id: str):
    """
    Receives user_input (string), chat_state (list of {'role':..., 'content':...}),
    user_id (string). Returns: (clear_input_str, new_chat_state)
    """
    uid = (user_id or os.getenv("DEFAULT_USER", "vinayak")).strip() or "vinayak"
    history = chat_state or []

    # Save user's message immediately
    try:
        save_message(uid, "user", user_input)
    except Exception:
        logger.exception("Failed to persist user message")

    # Build rows to pass to retriever: get last messages from DB (ensures persistence)
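    # Each row is a (role, content) tuple, oldest first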
    rows = get_last_messages(uid, limit=200)  # chronological order

    # Retrieve context using hybrid retriever (uses last 3 user messages internally)
    try:
        retrieved = retrieve_node_from_rows(rows)
        context = retrieved.get("context")
    except Exception:
        logger.exception("Retriever failed")
        context = None

    # Build prompt: SystemMessage + last 3 user messages (HumanMessage)
    prompt_msgs = []
    system_content = SYSTEM_PROMPT + (context or "No context found.")
    prompt_msgs.append(SystemMessage(content=system_content))

    # collect last 3 user messages (from rows)
    last_users = [r[1] for r in rows if r[0] == "user"][-3:]
    if not last_users:
        # fallback to current input if DB empty
        last_users = [user_input]
    # append each of the last user messages as HumanMessage (preserves order)
    for u in last_users:
        prompt_msgs.append(HumanMessage(content=u))

    # send to LLM
    try:
        raw = call_llm(prompt_msgs)
        answer = _extract_answer_from_response(raw) or ""
    except Exception as e:
        logger.exception("LLM call failed")
        answer = f"Sorry — I couldn't process that right now ({e})."

    # persist assistant reply
    try:
        save_message(uid, "assistant", answer)
    except Exception:
        logger.exception("Failed to persist assistant message")

    # update gradio chat state: append the current turn.
    # If the front-end state was empty, rehydrate from the DB instead; the DB
    # already holds the just-persisted user message and assistant reply, so
    # appending them again here would duplicate the turn.
    if history:
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": answer})
    else:
        history = load_history(uid)

    # return: clear the input box (""), updated history for gr.Chatbot(type="messages")
    return "", history


# ---- Minimal / attractive Gradio UI ----
with gr.Blocks(css=".gradio-container {max-width:900px; margin:0 auto;}") as demo:
    gr.Markdown("# 🤖 PrepGraph — RAG Tutor")
    with gr.Row():
        user_id_input = gr.Textbox(label="User ID (will be used to persist your memory)", value=os.getenv("DEFAULT_USER", "vinayak"))
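    # type="messages" means the component expects a list of
    # {"role": ..., "content": ...} dicts, which chat_interface returns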
    chatbot = gr.Chatbot(label="Conversation", type="messages")

    with gr.Row():
        msg = gr.Textbox(placeholder="Ask anything about your course material...", show_label=False)
        send = gr.Button("Send")

    with gr.Row():
        clear_ui = gr.Button("Clear Chat")

    # Load history at page load (and when user_id changes)
    demo.load(load_history, [user_id_input], [chatbot])
    user_id_input.change(load_history, [user_id_input], [chatbot])

    # Bind send
    msg.submit(chat_interface, [msg, chatbot, user_id_input], [msg, chatbot])
    send.click(chat_interface, [msg, chatbot, user_id_input], [msg, chatbot])

    # just clears the UI, not the DB
    clear_ui.click(lambda: [], None, chatbot)

if __name__ == "__main__":
    demo.launch()