Spaces:
Running
Running
| """ | |
| LangGraph — Intelligent Model Router for STEM Copilot. | |
| Routes queries to the best free OpenRouter model based on intent: | |
| - Math / derivation / proofs → top reasoning models (rotated) | |
| - Physics / chemistry concepts → strong general models (rotated) | |
| - Casual greetings ("hi", "ok") → lightweight models | |
| - Image understanding → vision-capable models (verified) | |
| Model pools are fetched from the OpenRouter API, cached for 10 min, | |
| and rotated per-category to avoid per-model rate limits. | |
| """ | |
import json
import re
import sqlite3
import threading
import time
import urllib.request
from typing import TypedDict, Annotated

from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_openrouter import ChatOpenRouter
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

from config import OPENROUTER_API_KEY, DB_PATH
import prompts
| # ── Checkpointer ────────────────────────────────────────────── | |
# One module-level SQLite connection backs the LangGraph checkpointer.
# check_same_thread=False disables sqlite3's same-thread check so the
# connection may be used from threads other than the one that opened it.
_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
# Passed to StateGraph.compile() at the bottom of this module.
checkpointer = SqliteSaver(conn=_conn)
| # ── Constants ───────────────────────────────────────────────── | |
# Model id used when no pool is available (fetch failure or empty pool).
DEFAULT_MODEL = "openai/gpt-oss-120b:free"
# How long a fetched model list stays fresh before _fetch_pools() re-fetches.
MODELS_TTL = 10 * 60  # seconds
| # ── Query Classification ────────────────────────────────────── | |
# Whole messages (lower-cased) that are treated as small talk.
_CASUAL_EXACT = frozenset([
    "hi", "hii", "hiii", "hello", "hey", "yo", "sup", "hola",
    "thanks", "thank you", "thankyou", "thx", "ty",
    "ok", "okay", "k", "kk", "fine", "alright",
    "bye", "goodbye", "see you", "later",
    "good morning", "good night", "good evening", "gm", "gn",
    "welcome", "cool", "nice", "great", "awesome", "wow",
    "got it", "understood", "sure", "yes", "no", "yeah", "nah", "yep", "nope",
    "hm", "hmm", "oh", "ah", "lol", "haha", "hehe", "xd",
    "what", "nothing", "nvm", "nevermind", "nm",
])

# Keywords that mark a maths / derivation style query.
_REASONING_KW = [
    "derive", "derivation", "prove", "proof", "solve", "equation",
    "integral", "integrate", "differentiate", "differentiation",
    "formula", "calculate", "calculation", "compute",
    "theorem", "lemma", "corollary",
    "limit", "matrix", "determinant", "vector", "eigen",
    "trigonometry", "trigonometric",
    "quadratic", "polynomial", "logarithm",
    "calculus", "algebra", "geometry", "coordinate",
    "probability", "permutation", "combination",
    "step by step", "show steps", "show working", "work out",
    "simplify", "factorize", "factorise", "expand",
    "numerator", "denominator", "fraction",
    "dx", "dy", "dz", " lim ",
]

# Keywords that mark a physics / chemistry / biology concept query.
_SCIENCE_KW = [
    "physics", "chemistry", "biology",
    "molecule", "atom", "electron", "proton", "neutron",
    "ion", "isotope", "nucleus",
    "force", "energy", "momentum", "velocity", "acceleration",
    "wave", "wavelength", "frequency", "amplitude",
    "light", "optics", "lens", "mirror", "refraction", "reflection",
    "electric", "magnetic", "electromagnetic", "circuit", "resistance",
    "gravity", "gravitational", "newton", "coulomb",
    "reaction", "compound", "element", "bond", "orbital",
    "hybridisation", "hybridization", "valence",
    "thermodynamics", "entropy", "enthalpy",
    "kinematics", "dynamics",
    "quantum", "relativity", "nuclear",
    "acid", "base", "salt", "buffer",
    "oxidation", "reduction", "redox",
    "mole", "molarity", "avogadro",
    "ncert", "class 11", "class 12", "class xi", "class xii",
]

# Very short tokens that must match as a complete word; as a mere prefix
# "dy" would also hit "dynamics" and "lim" would hit "limb".
_WHOLE_WORD_KW = frozenset({"dx", "dy", "dz", "lim"})

# Trailing characters ignored when comparing against _CASUAL_EXACT.
_TRAILING_PUNCT = " !.?,~"


def _keyword_pattern(keywords):
    """Compile *keywords* into one regex that matches them at a word start.

    Anchoring on the left word boundary keeps inflections ("solved", "ionic",
    "eigenvalue") while rejecting hits buried inside unrelated words
    ("ion" in "question", "prove" in "improve", "fraction" in "refraction").
    """
    alternatives = []
    for kw in keywords:
        word = kw.strip()
        closing = r"\b" if word in _WHOLE_WORD_KW else ""
        alternatives.append(re.escape(word) + closing)
    return re.compile(r"\b(?:" + "|".join(alternatives) + ")")


_REASONING_RE = _keyword_pattern(_REASONING_KW)
_SCIENCE_RE = _keyword_pattern(_SCIENCE_KW)


def _classify(text: str) -> str:
    """Return one of: casual, reasoning, science, general.

    Order of checks:
      1. the whole message (ignoring trailing punctuation) is a known
         greeting / acknowledgement  -> "casual"
      2. a reasoning keyword starts a word  -> "reasoning"
      3. a science keyword starts a word    -> "science"
      4. anything shorter than 6 characters -> "casual"
      5. otherwise                          -> "general"

    Keywords are tested before the length shortcut so that short but real
    questions such as "dy/dx" or "atom" are not sent to the casual pool.
    """
    t = text.strip().lower()
    if t in _CASUAL_EXACT or t.rstrip(_TRAILING_PUNCT) in _CASUAL_EXACT:
        return "casual"
    if _REASONING_RE.search(t):
        return "reasoning"
    if _SCIENCE_RE.search(t):
        return "science"
    if len(t) < 6:
        return "casual"
    return "general"
| # ── Model Scoring ───────────────────────────────────────────── | |
# A parameter count such as "70b", "0.5b" or "8x7b" that stands on its own:
# not glued to a preceding letter/digit/dot (so the "3b" in "a3b" or in
# "123b" is ignored) and not followed by more alphanumerics (so "4bit" is
# ignored). Group 1 is the optional expert multiplier, group 2 the size.
_SIZE_RE = re.compile(r"(?<![a-z0-9.])(?:(\d+)x)?(\d+(?:\.\d+)?)b(?![a-z0-9])")

# (minimum size in billions, bonus points), largest first. Every size tag the
# router used to special-case (253b ... 0.5b) lands on the same bonus it had
# before; sizes in between now fall into the nearest lower tier instead of
# getting no bonus or matching a smaller tag by accident.
_SIZE_TIERS = (
    (250, 100), (200, 95), (120, 90), (110, 88), (100, 85),
    (70, 70), (65, 68), (40, 55), (32, 42), (27, 40),
    (13, 20), (8, 15), (3, -20), (1, -40), (0, -50),
)

# Flat bonus per model family whose name appears in the id.
_FAMILY_BONUS = (
    ("nemotron", 15), ("gpt", 10), ("llama", 5), ("qwen", 5), ("deepseek", 8),
)


def _score(model_id: str) -> int:
    """Higher score → better for hard reasoning tasks.

    Starts from 50, adds a bonus for the model's parameter count (parsed
    from the id; the largest count wins when several appear, and "NxMb" is
    treated as N*M billion), then adds a flat bonus for each known family
    name found in the id. Ids without a recognisable size get no size bonus.
    """
    m = model_id.lower()
    s = 50

    sizes = [
        float(size) * (int(experts) if experts else 1)
        for experts, size in _SIZE_RE.findall(m)
    ]
    if sizes:
        largest = max(sizes)
        # The last tier has floor 0, so a match is always found.
        s += next(pts for floor, pts in _SIZE_TIERS if largest >= floor)

    for family, pts in _FAMILY_BONUS:
        if family in m:
            s += pts
    return s
| # ── Model Pool Fetching & Caching ───────────────────────────── | |
# Pools from the last successful fetch; None until _fetch_pools() succeeds once.
_cache: dict | None = None
# time.time() value recorded when _cache was last stored.
_cache_at = 0.0
# Held by _fetch_pools() while reading or writing _cache / _cache_at.
_lock = threading.Lock()
# Round-robin position per pool name, advanced by _pick().
_counters: dict[str, int] = {}
def _accepts_images(model: dict) -> bool:
    """Return True when a model record lists images among its *inputs*."""
    # The API may send null for these fields; treat that like "missing".
    arch = model.get("architecture") or {}
    # "modality" reads "<inputs>-><outputs>" (e.g. "text+image->text"). Only
    # the input side counts, otherwise a model that merely *generates* images
    # would be mistaken for one that can read them.
    inputs = (arch.get("modality") or "").partition("->")[0]
    return "image" in inputs or "image" in (arch.get("input_modalities") or [])


def _fetch_pools() -> dict:
    """Fetch free models, detect vision support, build sorted pools.

    Returns a dict with the keys "reasoning", "science", "general", "casual"
    and "vision"; each value is a list of ``{"id", "score", "vision"}``
    entries. A cached result younger than MODELS_TTL seconds is returned
    without touching the network. The HTTP request itself runs outside the
    lock, so concurrent callers with a stale cache may each fetch.

    On any failure the error is printed and the previous cache is returned
    if there is one. Otherwise every text pool falls back to DEFAULT_MODEL
    and the vision pool is left empty, so image turns fall through to text
    routing instead of being sent, images included, to a text-only model.
    """
    global _cache, _cache_at

    with _lock:
        if _cache and (time.time() - _cache_at) < MODELS_TTL:
            return _cache

    try:
        req = urllib.request.Request(
            "https://openrouter.ai/api/v1/models",
            headers={
                "HTTP-Referer": "https://stemcopilot.app",
                "User-Agent": "STEMCopilot/1.0",
            },
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            raw = json.loads(resp.read().decode())

        text_all, vision_all = [], []
        for m in raw.get("data", []):
            mid = m.get("id") or ""
            if not mid.endswith(":free"):
                continue
            has_vision = _accepts_images(m)
            entry = {"id": mid, "score": _score(mid), "vision": has_vision}
            # Every free model is a text candidate; vision ones are listed twice.
            text_all.append(entry)
            if has_vision:
                vision_all.append(entry)

        # Sort high→low; the casual pool below takes the low end.
        text_all.sort(key=lambda x: x["score"], reverse=True)
        vision_all.sort(key=lambda x: x["score"], reverse=True)

        pools = {
            "reasoning": text_all[:6],    # best 6
            "science": text_all[:8],      # best 8
            "general": text_all[:10],     # best 10
            # Five lowest-scored models, lowest first; with fewer than five
            # models available, just reuse the first three.
            "casual": text_all[-5:][::-1] if len(text_all) >= 5 else text_all[:3],
            "vision": vision_all,         # all vision-capable
        }

        with _lock:
            _cache = pools
            _cache_at = time.time()

        print(f"[ROUTER] {len(text_all)} free models, {len(vision_all)} vision-capable")
        for cat in ("reasoning", "casual", "vision"):
            ids = [e["id"] for e in pools[cat][:3]]
            if ids:
                print(f"[ROUTER] {cat}: {ids}")
        return pools

    except Exception as exc:
        # Best effort: routing must keep working when the model list is
        # unreachable or malformed, so log and degrade instead of raising.
        print(f"[ROUTER] Fetch failed: {exc}")
        if _cache:
            return _cache
        fallback_entry = {"id": DEFAULT_MODEL, "score": 50, "vision": False}
        fallback = {
            k: [fallback_entry]
            for k in ("reasoning", "science", "general", "casual")
        }
        # DEFAULT_MODEL is not flagged as vision-capable, so it must not be
        # offered for image turns.
        fallback["vision"] = []
        return fallback
| # ── Model Picker (round-robin) ──────────────────────────────── | |
def _rotate(key: str, pool: list) -> str:
    """Return the next model id from *pool*, advancing the counter for *key*."""
    idx = _counters.get(key, 0) % len(pool)
    _counters[key] = idx + 1
    return pool[idx]["id"]


def _pick(category: str, has_image: bool = False) -> tuple[str, str]:
    """
    Return (model_id, actual_category).

    With has_image set, a model is taken from the "vision" pool and the
    returned category is "vision". Only entries flagged ``vision: True`` are
    eligible, so a placeholder entry that is not vision-capable is never
    handed an image. If no such model exists, routing falls back to the text
    pool for *category* and the caller must strip images.

    Text routing uses the pool for *category*, or the "general" pool when
    that one is missing or empty, and DEFAULT_MODEL when both are empty.
    Models are handed out round-robin per category.
    """
    pools = _fetch_pools()

    if has_image:
        verified = [e for e in pools.get("vision", []) if e.get("vision")]
        if verified:
            return _rotate("vision", verified), "vision"
        # No vision models — caller must strip images.
        # Fall through to text routing.

    pool = pools.get(category) or pools.get("general") or []
    if not pool:
        return DEFAULT_MODEL, category
    return _rotate(category, pool), category
| # ── Message helpers ─────────────────────────────────────────── | |
def _extract_text(messages: list) -> str:
    """Get raw text from the last message in *messages* (handles multimodal).

    String content is returned as is. List content is reduced to its text:
    plain-string parts and ``{"type": "text"}`` parts are joined with single
    spaces, while every other part (e.g. ``image_url``) is ignored. An empty
    list, a message without ``content`` or a ``None`` content yields "".
    """
    if not messages:
        return ""

    content = getattr(messages[-1], "content", "")
    if content is None:
        return ""
    if not isinstance(content, list):
        return str(content)

    parts = []
    for part in content:
        if isinstance(part, str):
            # Content lists may mix bare strings with typed dict parts.
            parts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            # "text" may be present but None; join() needs a str.
            parts.append(str(part.get("text") or ""))
    return " ".join(parts)
def _strip_images(messages: list) -> list:
    """Remove all non-text content from messages. Preserves message types.

    Messages whose content is not a list are passed through untouched (same
    object). A message with list content is replaced by a new instance of the
    same class holding only its text: plain-string parts and
    ``{"type": "text"}`` parts joined with spaces. When no text remains, the
    placeholder "(user sent an image)" is used. The replacement is built
    from ``content`` alone; other fields of the original are not copied.
    """
    out = []
    for msg in messages:
        content = msg.content
        if not isinstance(content, list):
            out.append(msg)
            continue

        pieces = []
        for part in content:
            if isinstance(part, str):
                # Content lists may mix bare strings with typed dict parts.
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(str(part.get("text") or ""))

        text = " ".join(p for p in pieces if p).strip()
        if not text:
            text = "(user sent an image)"
        # Preserve original message class (HumanMessage, AIMessage, etc.)
        out.append(msg.__class__(content=text))
    return out
| # ── LLM factory ─────────────────────────────────────────────── | |
def _make_llm(api_key: str, model_id: str):
    """Build a streaming ChatOpenRouter client for *model_id*.

    A non-empty *api_key* is used as given; otherwise the module-level
    OPENROUTER_API_KEY is used.

    Raises:
        ValueError: neither key is set.
    """
    resolved_key = api_key if api_key else OPENROUTER_API_KEY
    if resolved_key:
        settings = {
            "model": model_id,
            "openrouter_api_key": resolved_key,
            "temperature": 0.5,
            "max_tokens": 4096,
            "max_retries": 3,
            "streaming": True,
        }
        return ChatOpenRouter(**settings)

    raise ValueError(
        "No API key available. Please add your OpenRouter key in Settings."
    )
| # ── LangGraph state & node ──────────────────────────────────── | |
class ChatState(TypedDict):
    """State carried through the graph: the conversation's message list."""

    # add_messages is the reducer attached to this key; LangGraph uses it to
    # merge the messages a node returns into the existing list.
    messages: Annotated[list[BaseMessage], add_messages]
def _is_verified_vision_model(model_id: str) -> bool:
    """Return True when *model_id* is flagged vision-capable in the model pools."""
    return any(
        entry.get("id") == model_id and entry.get("vision")
        for entry in _fetch_pools().get("vision", [])
    )


def chat_node(state: ChatState, config: RunnableConfig):
    """Answer the latest message with a model chosen for the query.

    Reads these keys from ``config["configurable"]`` (all optional):
    persona, context, language, username, student_profile, user_api_key,
    model (explicit model override) and has_image.

    Steps: classify the last message's text, pick a model (the override wins
    over routing), strip images from the history unless the chosen model is
    vision-capable, then invoke the model with the system prompt from
    ``prompts.build`` followed by the history.

    Returns ``{"messages": [response]}``.

    Raises:
        ValueError: no API key is available (raised by ``_make_llm``).
    """
    cfg = config.get("configurable", {})
    persona = cfg.get("persona", "nerd")
    context = cfg.get("context", "")
    language = cfg.get("language", "auto")
    username = cfg.get("username", "")
    profile = cfg.get("student_profile", "")
    api_key = cfg.get("user_api_key", "")
    override = cfg.get("model", "")
    has_image = cfg.get("has_image", False)

    # 1) Classify
    user_text = _extract_text(state["messages"])
    category = _classify(user_text)

    # 2) Pick model
    if override:
        model_id = override
        # Keep images only when the override is a model the router has
        # verified as vision-capable. Any other override is unverified, so it
        # takes the conservative strip-images path below.
        keep_images = has_image and _is_verified_vision_model(override)
        actual = "vision" if keep_images else category
    else:
        model_id, actual = _pick(category, has_image=has_image)

    # 3) ALWAYS strip images from history when model is not vision-capable.
    #    LangGraph checkpoint may contain old image messages from prior turns
    #    that would cause "No endpoints found that support image input" on
    #    text-only models.
    messages = state["messages"]
    if actual != "vision":
        messages = _strip_images(messages)

    # 4) Build system prompt & invoke
    system_msg = SystemMessage(
        content=prompts.build(persona, context, language, username, profile)
    )
    print(f"[ROUTER] category={category} model={model_id} vision={actual == 'vision'}")

    llm = _make_llm(api_key, model_id)
    resp = llm.invoke([system_msg] + messages)
    return {"messages": [resp]}
| # ── Compile graph ───────────────────────────────────────────── | |
# Single-node graph: START -> chat_node -> END.
_g = StateGraph(ChatState)
_g.add_node("chat_node", chat_node)
_g.add_edge(START, "chat_node")
_g.add_edge("chat_node", END)
# Compiled with the SQLite checkpointer defined near the top of this module.
chatbot = _g.compile(checkpointer=checkpointer)