"""
LangGraph — Intelligent Model Router for STEM Copilot.

Routes queries to the best free OpenRouter model based on intent:
  - Math / derivation / proofs     → top reasoning models (rotated)
  - Physics / chemistry concepts   → strong general models (rotated)
  - Casual greetings ("hi", "ok")  → lightweight models
  - Image understanding            → vision-capable models (verified)

Model pools are fetched from the OpenRouter API, cached for 10 min,
and rotated per-category to avoid per-model rate limits.
"""

import json
import re
import sqlite3
import threading
import time
import urllib.request
from typing import TypedDict, Annotated

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_openrouter import ChatOpenRouter
from langgraph.checkpoint.sqlite import SqliteSaver

from config import OPENROUTER_API_KEY, DB_PATH
import prompts

# ── Checkpointer ──────────────────────────────────────────────
# check_same_thread=False lets the single connection be used from threads
# other than the one that created it.
_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=_conn)

# ── Constants ─────────────────────────────────────────────────
DEFAULT_MODEL = "openai/gpt-oss-120b:free"
MODELS_TTL = 10 * 60  # seconds

# ── Query Classification ─────────────────────────────────────
# Whole messages (after strip + lowercase) that are treated as small talk.
_CASUAL_EXACT = frozenset([
    "hi", "hii", "hiii", "hello", "hey", "yo", "sup", "hola",
    "thanks", "thank you", "thankyou", "thx", "ty",
    "ok", "okay", "k", "kk", "fine", "alright",
    "bye", "goodbye", "see you", "later",
    "good morning", "good night", "good evening", "gm", "gn",
    "welcome", "cool", "nice", "great", "awesome", "wow",
    "got it", "understood", "sure", "yes", "no", "yeah", "nah",
    "yep", "nope", "hm", "hmm", "oh", "ah", "lol", "haha",
    "hehe", "xd", "what", "nothing", "nvm", "nevermind", "nm",
])

# Substrings that mark a query as needing a strong reasoning model.
_REASONING_KW = [
    "derive", "derivation", "prove", "proof", "solve", "equation",
    "integral", "integrate", "differentiate", "differentiation",
    "formula", "calculate", "calculation", "compute",
    "theorem", "lemma", "corollary", "limit",
    "matrix", "determinant", "vector", "eigen",
    "trigonometry", "trigonometric", "quadratic", "polynomial",
    "logarithm", "calculus", "algebra", "geometry", "coordinate",
    "probability", "permutation", "combination",
    "step by step", "show steps", "show working", "work out",
    "simplify", "factorize", "factorise", "expand",
    "numerator", "denominator", "fraction",
    "dx", "dy", "dz", " lim ",
]

# Substrings that mark a query as a science-concept question.
_SCIENCE_KW = [
    "physics", "chemistry", "biology", "molecule", "atom",
    "electron", "proton", "neutron", "ion", "isotope", "nucleus",
    "force", "energy", "momentum", "velocity", "acceleration",
    "wave", "wavelength", "frequency", "amplitude",
    "light", "optics", "lens", "mirror", "refraction", "reflection",
    "electric", "magnetic", "electromagnetic", "circuit", "resistance",
    "gravity", "gravitational", "newton", "coulomb",
    "reaction", "compound", "element", "bond", "orbital",
    "hybridisation", "hybridization", "valence",
    "thermodynamics", "entropy", "enthalpy",
    "kinematics", "dynamics", "quantum", "relativity", "nuclear",
    "acid", "base", "salt", "buffer",
    "oxidation", "reduction", "redox",
    "mole", "molarity", "avogadro",
    "ncert", "class 11", "class 12", "class xi", "class xii",
]


def _classify(text: str) -> str:
    """Return one of: casual, reasoning, science, general.

    The text is stripped and lower-cased first. Exact small-talk phrases and
    anything shorter than 6 characters are "casual"; otherwise the reasoning
    keywords are checked before the science keywords, and "general" is the
    fallback.

    NOTE(review): keywords are matched as plain substrings, so short ones
    also hit inside unrelated words (e.g. "ion" matches "question").
    """
    t = text.strip().lower()
    if t in _CASUAL_EXACT or len(t) < 6:
        return "casual"
    if any(kw in t for kw in _REASONING_KW):
        return "reasoning"
    if any(kw in t for kw in _SCIENCE_KW):
        return "science"
    return "general"


# ── Model Scoring ─────────────────────────────────────────────
# (size tag, points) — checked in this order; the first tag found wins.
_SIZE_POINTS = (
    ("253b", 100), ("250b", 100), ("200b", 95), ("120b", 90),
    ("110b", 88), ("100b", 85), ("70b", 70), ("72b", 70),
    ("65b", 68), ("49b", 55), ("46b", 55), ("40b", 55),
    ("27b", 40), ("32b", 42), ("34b", 42), ("13b", 20),
    ("14b", 20), ("8b", 15), ("3b", -20), ("1b", -40),
    ("0.5b", -50),
)

# A tag only counts when it starts a token: the lookbehind rejects a preceding
# digit, letter or "." so that "11b" is not read as "1b", "33b" as "3b", or
# an MoE active-parameter suffix such as "a3b" as "3b".
_SIZE_PATTERNS = tuple(
    (re.compile(r"(?<![a-z0-9.])" + re.escape(tag)), pts)
    for tag, pts in _SIZE_POINTS
)

# (family substring, bonus points) — every matching family is added.
_FAMILY_BONUS = (
    ("nemotron", 15),
    ("gpt", 10),
    ("llama", 5),
    ("qwen", 5),
    ("deepseek", 8),
)


def _score(model_id: str) -> int:
    """Higher score ≈ better for hard reasoning tasks.

    Starts from a base of 50, adds the points of the first parameter-size
    tag found in the lower-cased id (at most one), then adds a bonus for
    each model-family substring present.
    """
    m = model_id.lower()
    s = 50
    for pattern, pts in _SIZE_PATTERNS:
        if pattern.search(m):
            s += pts
            break
    for family, bonus in _FAMILY_BONUS:
        if family in m:
            s += bonus
    return s


# ── Model Pool Fetching & Caching ─────────────────────────────
_cache: dict | None = None
_cache_at = 0.0
_lock = threading.Lock()
_counters: dict[str, int] = {}

# Categories that route to text models (everything except "vision").
_TEXT_CATEGORIES = ("reasoning", "science", "general", "casual")


def _fetch_pools() -> dict:
    """Fetch free models, detect vision support, build sorted pools.

    Returns a dict mapping "reasoning", "science", "general", "casual" and
    "vision" to lists of ``{"id", "score", "vision"}`` entries. A cached
    result younger than MODELS_TTL seconds is returned without a request.

    If the fetch fails, the previous cache (even if stale) is returned when
    there is one; otherwise every text category falls back to DEFAULT_MODEL
    and the vision pool is left EMPTY. The fallback model is not known to
    accept images, so it must not be offered as a vision model — an empty
    pool makes ``_pick`` route to a text category, which in turn makes the
    caller strip images.
    """
    global _cache, _cache_at

    with _lock:
        if _cache and (time.time() - _cache_at) < MODELS_TTL:
            return _cache

    try:
        req = urllib.request.Request(
            "https://openrouter.ai/api/v1/models",
            headers={
                "HTTP-Referer": "https://stemcopilot.app",
                "User-Agent": "STEMCopilot/1.0",
            },
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            raw = json.loads(resp.read().decode())

        text_all, vision_all = [], []
        for m in raw.get("data") or []:
            # "or" fallbacks: a JSON null for any of these fields would
            # otherwise raise and discard the whole model list.
            mid = m.get("id") or ""
            if not mid.endswith(":free"):
                continue

            arch = m.get("architecture") or {}
            modality = arch.get("modality") or ""
            input_mods = arch.get("input_modalities") or []
            has_vision = ("image" in modality) or ("image" in input_mods)

            entry = {"id": mid, "score": _score(mid), "vision": has_vision}
            text_all.append(entry)
            if has_vision:
                vision_all.append(entry)

        # Both lists are sorted best-first; the casual pool is taken from
        # the low-scoring tail below.
        text_all.sort(key=lambda x: x["score"], reverse=True)
        vision_all.sort(key=lambda x: x["score"], reverse=True)

        pools = {
            "reasoning": text_all[:6],    # best 6
            "science": text_all[:8],      # best 8
            "general": text_all[:10],     # best 10
            # five lowest-scoring models, lowest first
            "casual": text_all[-5:][::-1] if len(text_all) >= 5 else text_all[:3],
            "vision": vision_all,         # all vision-capable
        }

        with _lock:
            _cache = pools
            _cache_at = time.time()

        print(f"[ROUTER] {len(text_all)} free models, {len(vision_all)} vision-capable")
        for cat in ("reasoning", "casual", "vision"):
            ids = [e["id"] for e in pools[cat][:3]]
            if ids:
                print(f"[ROUTER]   {cat}: {ids}")

        return pools

    except Exception as exc:
        # Deliberate best-effort: routing must keep working when the model
        # list cannot be fetched or parsed.
        print(f"[ROUTER] Fetch failed: {exc}")
        if _cache:
            return _cache
        fallback_entry = {"id": DEFAULT_MODEL, "score": 50, "vision": False}
        fallback = {k: [fallback_entry] for k in _TEXT_CATEGORIES}
        fallback["vision"] = []
        return fallback


# ── Model Picker (round-robin) ────────────────────────────────
def _pick(category: str, has_image: bool = False) -> tuple[str, str]:
    """
    Return (model_id, actual_category).

    If has_image but no vision model exists, falls back to text category.
    An unknown or empty category falls back to the "general" pool, and to
    DEFAULT_MODEL when that is empty too.

    Rotation uses the module-level ``_counters`` dict, which is updated
    without holding ``_lock``; the index is always reduced modulo the pool
    length before use.
    """
    pools = _fetch_pools()

    if has_image:
        v = pools.get("vision", [])
        if v:
            idx = _counters.get("vision", 0) % len(v)
            _counters["vision"] = idx + 1
            return v[idx]["id"], "vision"
        # No vision models — caller must strip images
        # Fall through to text routing

    pool = pools.get(category) or pools.get("general") or []
    if not pool:
        return DEFAULT_MODEL, category

    idx = _counters.get(category, 0) % len(pool)
    _counters[category] = idx + 1
    return pool[idx]["id"], category


# ── Message helpers ────────────────────────────────────────────
def _extract_text(messages: list) -> str:
    """Get raw text from the last message (handles multimodal).

    For list content, the "text" fields of all ``{"type": "text"}`` parts
    are joined with spaces; any other content is converted with ``str``.
    Returns "" for an empty message list.
    """
    if not messages:
        return ""
    content = messages[-1].content if hasattr(messages[-1], "content") else ""
    if isinstance(content, list):
        return " ".join(
            p.get("text", "")
            for p in content
            if isinstance(p, dict) and p.get("type") == "text"
        )
    return str(content)


def _strip_images(messages: list) -> list:
    """Remove all image_url content from messages. Preserves message types.

    Messages whose content is a list are rebuilt with only their text parts
    joined into one string; a message left with no text gets the placeholder
    "(user sent an image)". Messages with non-list content are passed
    through unchanged.
    """
    out = []
    for msg in messages:
        if isinstance(msg.content, list):
            text_parts = [
                p.get("text", "")
                for p in msg.content
                if isinstance(p, dict) and p.get("type") == "text"
            ]
            text = " ".join(t for t in text_parts if t).strip()
            if not text:
                text = "(user sent an image)"
            # Preserve original message class (HumanMessage, AIMessage, etc.)
            # Only `content` is passed to the constructor.
            out.append(msg.__class__(content=text))
        else:
            out.append(msg)
    return out


# ── LLM factory ───────────────────────────────────────────────
def _make_llm(api_key: str, model_id: str):
    """Build a streaming ChatOpenRouter client for ``model_id``.

    Uses ``api_key`` when given, otherwise the configured
    OPENROUTER_API_KEY.

    Raises:
        ValueError: if neither key is set.
    """
    key = api_key or OPENROUTER_API_KEY
    if not key:
        raise ValueError(
            "No API key available. Please add your OpenRouter key in Settings."
        )
    return ChatOpenRouter(
        model=model_id,
        openrouter_api_key=key,
        temperature=0.5,
        max_tokens=4096,
        max_retries=3,
        streaming=True,
    )


# ── LangGraph state & node ────────────────────────────────────
class ChatState(TypedDict):
    # Conversation history; add_messages is the reducer for node updates.
    messages: Annotated[list[BaseMessage], add_messages]


def chat_node(state: ChatState, config: RunnableConfig):
    """Classify the latest message, pick a model, and return its reply.

    Reads these keys from ``config["configurable"]``: persona, context,
    language, username, student_profile, user_api_key, model (explicit
    model override) and has_image.

    Returns ``{"messages": [response]}``.
    """
    cfg = config.get("configurable", {})
    persona = cfg.get("persona", "nerd")
    context = cfg.get("context", "")
    language = cfg.get("language", "auto")
    username = cfg.get("username", "")
    profile = cfg.get("student_profile", "")
    api_key = cfg.get("user_api_key", "")
    override = cfg.get("model", "")
    has_image = cfg.get("has_image", False)

    # 1) Classify
    user_text = _extract_text(state["messages"])
    category = _classify(user_text)

    # 2) Pick model
    if override:
        # NOTE(review): an override is always treated as text-only here
        # (`actual` is never "vision"), so images are stripped even if the
        # chosen model could accept them — confirm this is intended.
        model_id, actual = override, category
    else:
        model_id, actual = _pick(category, has_image=has_image)

    # 3) ALWAYS strip images from history when model is not vision-capable.
    #    LangGraph checkpoint may contain old image messages from prior turns
    #    that would cause "No endpoints found that support image input" on
    #    text-only models.
    messages = state["messages"]
    if actual != "vision":
        messages = _strip_images(messages)

    # 4) Build system prompt & invoke
    system_msg = SystemMessage(
        content=prompts.build(persona, context, language, username, profile)
    )

    print(f"[ROUTER] category={category} model={model_id} vision={actual == 'vision'}")

    llm = _make_llm(api_key, model_id)
    resp = llm.invoke([system_msg] + messages)
    return {"messages": [resp]}


# ── Compile graph ─────────────────────────────────────────────
_g = StateGraph(ChatState)
_g.add_node("chat_node", chat_node)
_g.add_edge(START, "chat_node")
_g.add_edge("chat_node", END)

chatbot = _g.compile(checkpointer=checkpointer)