| | import os |
| | from typing import Any, Dict, List, Optional |
| |
|
| | import pandas as pd |
| | from langchain_core.messages import HumanMessage |
| | from langchain_google_genai import ChatGoogleGenerativeAI |
| |
|
| | from agent.agent_graph import build_app |
| | from pipeline.utils_cool import df_to_payload, parse_user_choice |
| |
|
| | from .runtime_ctx import get_df_summary |
| | from dotenv import load_dotenv |
# Load variables from a .env file into the process environment at import
# time — presumably so os.getenv("GOOGLE_API_KEY") in ChatbotHandler.__init__
# can pick up a key that is only defined in .env.
load_dotenv()
| |
|
class ChatbotHandler:
    """Chat front-end that drives the agent graph over an uploaded dataset.

    ``self.ctx["graph_app"]`` holds the graph built by ``build_app`` (None until
    a dataset is uploaded). ``self.ctx["state"]`` holds the state passed into and
    merged back from each ``invoke`` call. ``update_context`` is called when a
    dataset is uploaded. ``respond`` is called per chat message and appends
    ``(user_message, reply)`` tuples to the chat history.
    """

    # State keys copied back from a graph result into the persisted state.
    # "max_steps" is deliberately not copied back: respond() recomputes it
    # for every turn.
    _STATE_KEYS = (
        "df_payload",
        "results",
        "steps_taken",
        "confirmed_step",
        "confirmed_params",
        "last_task",
        "plan",
        "messages",
    )
    # Step budget for a fresh workflow; also the per-turn minimum in respond().
    _DEFAULT_MAX_STEPS = 8
    # Lower-cased message roles/types that identify assistant output.
    _AI_ROLES = ("ai", "assistant", "aimessage")

    def __init__(self):
        self.ctx: Dict[str, Any] = {
            # Agent graph; stays None until update_context receives a DataFrame.
            "graph_app": None,
            # Workflow state carried between graph invocations.
            "state": self._fresh_state([]),
        }

        # Summary + confirmation prompt from the last upload. respond() shows it
        # once, as the first history entry, when the history is still empty.
        self._boot_text: Optional[str] = None

        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

    def _fresh_state(self, messages: List[Any], df_payload: Any = None) -> Dict[str, Any]:
        """Return a new workflow state seeded with ``messages`` and ``df_payload``."""
        return {
            "df_payload": df_payload,
            "results": [],
            "steps_taken": 0,
            "confirmed_step": None,
            "confirmed_params": {},
            "last_task": None,
            "plan": None,
            "messages": list(messages),
            "max_steps": self._DEFAULT_MAX_STEPS,
        }

    def _merge_graph_result(self, st: Dict[str, Any], final: Dict[str, Any],
                            fallback: Dict[str, Any]) -> None:
        """Copy ``_STATE_KEYS`` from a graph result into ``st`` in place.

        A key missing from ``final`` falls back to ``fallback``, then to the
        current value in ``st``.
        """
        for key in self._STATE_KEYS:
            st[key] = final.get(key, fallback.get(key, st.get(key)))

    def _format_summary(self, s: Dict[str, Any]) -> str:
        """Render a dataset-summary dict as a short markdown bullet list.

        Reads the optional keys "columns", "dtypes", "shape", "label_guess",
        "task_guess" and "issues". Missing values render as placeholders.
        Long lists are truncated with "…".
        """
        cols = s.get("columns") or []
        dtypes = s.get("dtypes") or {}
        # Pad to two entries so a missing or short shape cannot raise IndexError.
        n_rows, n_cols = (tuple(s.get("shape") or ()) + (None, None))[:2]
        label_guess = s.get("label_guess") or "None"
        task_guess = s.get("task_guess") or "Unknown"
        issues = s.get("issues") or []
        if isinstance(issues, str):
            # A single message must not be split into characters by slicing.
            issues = [issues]
        issues = list(issues)

        # Show at most 8 dtype pairs to keep the summary compact.
        dt_pairs = [f"{k}: {v}" for k, v in list(dtypes.items())[:8]]
        if len(dtypes) > 8:
            dt_pairs.append("…")

        lines = [
            "### Dataset summary",
            f"- Shape: {n_rows} rows × {n_cols} columns",
            f"- Columns: {', '.join(map(str, cols[:10]))}{'…' if len(cols) > 10 else ''}",
            f"- Dtypes: {', '.join(dt_pairs)}",
            f"- Label guess: {label_guess}",
            f"- Task guess: {task_guess}",
        ]
        if issues:
            # map(str, ...) so non-string issue entries cannot break the join.
            lines.append(
                f"- Potential issues: {'; '.join(map(str, issues[:3]))}"
                f"{'…' if len(issues) > 3 else ''}"
            )
        return "\n".join(lines)

    def _confirmation_prompt(self, s: Dict[str, Any]) -> str:
        """Build the prompt asking the user to confirm (or supply) task and label."""
        task_guess = str(s.get("task_guess") or "").lower()
        label_guess = s.get("label_guess")
        needs_task = task_guess not in {"classification", "regression", "unsupervised"}
        # Supervised tasks need a label column; unsupervised ones do not.
        needs_label = task_guess in {"classification", "regression"} and not label_guess

        if needs_task or needs_label:
            return (
                "\n\nPlease confirm the task"
                + (" and label column" if needs_label else "")
                + ". For example: `task=classification label=noisy_letter_grade`."
            )
        return (
            f"\n\nIf that looks right, say `confirm task={task_guess}"
            + (f" label={label_guess}`" if label_guess else "`")
            + " and I’ll fetch SOTA and propose a plan."
        )

    def update_context(self, file_path: Optional[str], data_type: Optional[str], df: Optional["pd.DataFrame"]):
        """Reset the workflow for a newly uploaded dataset and build the boot text.

        Args:
            file_path: Not used by this method.
            data_type: Not used by this method.
            df: The uploaded dataset. When None, nothing changes and "" is returned.

        Returns:
            The dataset summary plus a task/label confirmation prompt. The same
            text is stored in ``self._boot_text`` for respond() to show.
        """
        if df is None:
            return ""

        # A fresh graph per upload, then reset the shared state dict in place
        # (st.update keeps the same object referenced by self.ctx).
        self.ctx["graph_app"] = build_app(self.llm)
        df_payload = df_to_payload(df)
        st = self.ctx["state"]
        st.update(self._fresh_state(
            [HumanMessage(content="A new dataset was uploaded. Start the workflow.")],
            df_payload=df_payload,
        ))

        # Kick off the workflow. Only the resulting state is kept; the graph's
        # reply text is not part of the returned boot text.
        final = self.ctx["graph_app"].invoke(st)
        self._merge_graph_result(st, final, st)

        # NOTE(review): get_df_summary presumably reads a summary stored in the
        # runtime context during the graph run above — confirm.
        summary = get_df_summary() or {}
        self._boot_text = self._format_summary(summary) + self._confirmation_prompt(summary)
        return self._boot_text

    @staticmethod
    def _messages_after(messages: Any, marker: Any) -> List[Any]:
        """Return the messages after the last one that *is* ``marker``.

        If ``marker`` is not present (matched by identity), return all messages.
        """
        msgs = list(messages or [])
        for idx in range(len(msgs) - 1, -1, -1):
            if msgs[idx] is marker:
                return msgs[idx + 1:]
        return msgs

    def respond(self, message: str, history: List):
        """Handle one chat message and return the updated history.

        Args:
            message: Raw user input. Whitespace is stripped; blank input is ignored.
            history: List of ``(user, bot)`` tuples, mutated in place. None starts
                a new list.

        Returns:
            ``(history, "")``. The empty string presumably clears the UI input box.
        """
        if history is None:
            history = []
        msg = (message or "").strip()
        if not msg:
            return history, ""

        # Show the pending dataset summary once, at the top of a new chat.
        if self._boot_text and not history:
            history.append(("[system]", self._boot_text))
            self._boot_text = None

        graph_app = self.ctx.get("graph_app")
        if graph_app is None:
            history.append((msg, "Please upload a dataset first."))
            return history, ""

        st = self.ctx["state"]

        # An explicit step choice in the message is recorded before invoking;
        # new params are layered over previously confirmed ones.
        step, params = parse_user_choice(msg)
        if step:
            st["confirmed_step"] = step
            st["confirmed_params"] = {**(st.get("confirmed_params") or {}), **(params or {})}

        user_msg = HumanMessage(content=msg)
        messages = (st.get("messages") or []) + [user_msg]
        turn_state = {
            "messages": messages,
            "df_payload": st.get("df_payload"),
            "results": st.get("results", []),
            "steps_taken": st.get("steps_taken", 0),
            # Always leave at least 4 more steps for this turn.
            "max_steps": max(self._DEFAULT_MAX_STEPS, st.get("steps_taken", 0) + 4),
            "confirmed_step": st.get("confirmed_step"),
            "confirmed_params": st.get("confirmed_params", {}),
            "last_task": st.get("last_task"),
            "plan": st.get("plan"),
        }

        final = graph_app.invoke(turn_state)
        self._merge_graph_result(st, final, turn_state)

        # Only consider messages produced after this turn's user message. If
        # the graph appended no AI message, the previous turn's answer would
        # otherwise be repeated. When user_msg cannot be located (e.g. the graph
        # replaced the list), this scans every returned message, as before.
        new_messages = self._messages_after(final.get("messages", []), user_msg)
        reply = self._extract_ai_text(new_messages) or "Done."
        history.append((msg, reply))
        return history, ""

    @staticmethod
    def _coerce_text(content: Any) -> str:
        """Flatten message content (str, list of parts, or other) into plain text."""
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = [
                str(c.get("text") or c.get("content") or c.get("data") or "")
                if isinstance(c, dict) else str(c)
                for c in content
            ]
            return " ".join(p for p in parts if p)
        return str(content)

    def _extract_ai_text(self, messages: List[Any]) -> str:
        """Return the text of the last assistant message in ``messages``.

        Message objects (via ``.type``/``.role``) and dicts (via "role"/"type")
        are both accepted. Roles are compared case-insensitively. Returns ""
        when no assistant message is present.
        """
        for m in reversed(messages or []):
            if isinstance(m, dict):
                role = m.get("role") or m.get("type")
                content = m.get("content")
            else:
                role = getattr(m, "type", None) or getattr(m, "role", None)
                content = getattr(m, "content", None)
            if str(role or "").lower() in self._AI_ROLES:
                return self._coerce_text(content)
        return ""
| |
|