# StemGraph_AI / graph.py
# Commit 869dc99 (Subh775): intelligent model routing and selection.
"""
LangGraph — Intelligent Model Router for STEM Copilot.
Routes queries to the best free OpenRouter model based on intent:
- Math / derivation / proofs → top reasoning models (rotated)
- Physics / chemistry concepts → strong general models (rotated)
- Casual greetings ("hi", "ok") → lightweight models
- Image understanding → vision-capable models (verified)
Model pools are fetched from the OpenRouter API, cached for 10 min,
and rotated per-category to avoid per-model rate limits.
"""
import json
import re
import sqlite3
import threading
import time
import urllib.request
from typing import TypedDict, Annotated

from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langchain_openrouter import ChatOpenRouter
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

from config import OPENROUTER_API_KEY, DB_PATH
import prompts
# ── Checkpointer ──────────────────────────────────────────────
# One module-wide SQLite connection backs the LangGraph checkpointer that the
# compiled graph uses to persist state between invocations.
# check_same_thread=False allows the connection to be used from threads other
# than the one that created it.
_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=_conn)
# ── Constants ─────────────────────────────────────────────────
# Model used when no routed model is available (empty pool or failed fetch).
DEFAULT_MODEL = "openai/gpt-oss-120b:free"
# How long _fetch_pools reuses its cached model pools before refetching.
MODELS_TTL = 10 * 60 # seconds
# ── Query Classification ─────────────────────────────────────
# Whole-message (after trim + lower-case) small talk → "casual".
_CASUAL_EXACT = frozenset([
    "hi", "hii", "hiii", "hello", "hey", "yo", "sup", "hola",
    "thanks", "thank you", "thankyou", "thx", "ty",
    "ok", "okay", "k", "kk", "fine", "alright",
    "bye", "goodbye", "see you", "later",
    "good morning", "good night", "good evening", "gm", "gn",
    "welcome", "cool", "nice", "great", "awesome", "wow",
    "got it", "understood", "sure", "yes", "no", "yeah", "nah", "yep", "nope",
    "hm", "hmm", "oh", "ah", "lol", "haha", "hehe", "xd",
    "what", "nothing", "nvm", "nevermind", "nm",
])
# Keywords suggesting math / derivation work → "reasoning" pool.
_REASONING_KW = [
    "derive", "derivation", "prove", "proof", "solve", "equation",
    "integral", "integrate", "differentiate", "differentiation",
    "formula", "calculate", "calculation", "compute",
    "theorem", "lemma", "corollary",
    "limit", "matrix", "determinant", "vector", "eigen",
    "trigonometry", "trigonometric",
    "quadratic", "polynomial", "logarithm",
    "calculus", "algebra", "geometry", "coordinate",
    "probability", "permutation", "combination",
    "step by step", "show steps", "show working", "work out",
    "simplify", "factorize", "factorise", "expand",
    "numerator", "denominator", "fraction",
    "dx", "dy", "dz", " lim ",
]
# Keywords suggesting a science concept → "science" pool.
_SCIENCE_KW = [
    "physics", "chemistry", "biology",
    "molecule", "atom", "electron", "proton", "neutron",
    "ion", "isotope", "nucleus",
    "force", "energy", "momentum", "velocity", "acceleration",
    "wave", "wavelength", "frequency", "amplitude",
    "light", "optics", "lens", "mirror", "refraction", "reflection",
    "electric", "magnetic", "electromagnetic", "circuit", "resistance",
    "gravity", "gravitational", "newton", "coulomb",
    "reaction", "compound", "element", "bond", "orbital",
    "hybridisation", "hybridization", "valence",
    "thermodynamics", "entropy", "enthalpy",
    "kinematics", "dynamics",
    "quantum", "relativity", "nuclear",
    "acid", "base", "salt", "buffer",
    "oxidation", "reduction", "redox",
    "mole", "molarity", "avogadro",
    "ncert", "class 11", "class 12", "class xi", "class xii",
]


def _keyword_pattern(keywords: list[str]) -> re.Pattern[str]:
    """Compile *keywords* into one alternation regex for lower-cased text.

    Keywords of three characters or fewer ("dx", "dy", "ion", " lim ") are
    too ambiguous as raw substrings ("dy" sits inside "study"/"already",
    "ion" inside "question"/"function"), so they must appear as a whole
    word, optionally followed by a plural "s". Longer keywords keep plain
    substring matching so compounds such as "photoelectric" still match.
    """
    parts = []
    for kw in keywords:
        word = kw.strip()
        if not word:
            continue
        if len(word) <= 3:
            # Letter-only boundaries: "dy/dx" and "2 dx" still match.
            parts.append(rf"(?<![a-z]){re.escape(word)}s?(?![a-z])")
        else:
            parts.append(re.escape(kw))
    # "(?!)" can never match, so an empty keyword list matches nothing.
    return re.compile("|".join(parts) or "(?!)")


# Built once at import time; _classify lower-cases its input before searching.
_REASONING_RE = _keyword_pattern(_REASONING_KW)
_SCIENCE_RE = _keyword_pattern(_SCIENCE_KW)


def _classify(text: str) -> str:
    """Return one of: casual, reasoning, science, general.

    Whole-message small talk, or anything shorter than 6 characters after
    trimming, is "casual". Otherwise reasoning keywords are checked before
    science keywords, and "general" is the fallback.
    """
    t = text.strip().lower()
    if t in _CASUAL_EXACT or len(t) < 6:
        return "casual"
    if _REASONING_RE.search(t):
        return "reasoning"
    if _SCIENCE_RE.search(t):
        return "science"
    return "general"
# ── Model Scoring ─────────────────────────────────────────────
# Parameter-count token in a model id, e.g. "120b", "8b", "0.5b".
# The look-behind rejects a number glued to a letter, digit or dot, so MoE
# active-parameter suffixes ("-a3b", "-a22b") and names like "e4b" are not
# taken as the model size; the look-ahead rejects things like "8bit".
_SIZE_RE = re.compile(r"(?<![\d.a-z])(\d+(?:\.\d+)?)b(?![a-z0-9])")

# (minimum size in billions, bonus), largest first; the first threshold the
# parsed size reaches applies. Reproduces the points of the former tag table
# (253b/250b→100, 120b→90, 70b/72b→70, 8b→15, 1b→-40, 0.5b→-50, ...) and
# fills the gaps between them monotonically.
_SIZE_POINTS = (
    (250, 100), (200, 95), (120, 90), (110, 88), (100, 85),
    (70, 70), (65, 68), (40, 55), (32, 42), (27, 40),
    (13, 20), (8, 15), (3, -20), (1, -40), (0, -50),
)


def _score(model_id: str) -> int:
    """Heuristic score for a model id; higher ≈ better for hard reasoning tasks.

    Starts at 50, adds a bonus for the parameter count parsed from the id
    (the largest one if several appear; no bonus when none is found), then
    small per-family bonuses. Parsing the number avoids the substring
    pitfalls of tag matching ("11b" read as "1b", "30b-a3b" read as "3b").
    """
    m = model_id.lower()
    s = 50
    sizes = [float(tok) for tok in _SIZE_RE.findall(m)]
    if sizes:
        size = max(sizes)
        # The final (0, ...) row guarantees a match for any parsed size.
        s += next(pts for floor, pts in _SIZE_POINTS if size >= floor)
    if "nemotron" in m:
        s += 15
    if "gpt" in m:
        s += 10
    if "llama" in m:
        s += 5
    if "qwen" in m:
        s += 5
    if "deepseek" in m:
        s += 8
    return s
# ── Model Pool Fetching & Caching ─────────────────────────────
_cache: dict | None = None  # pools from the last successful fetch
_cache_at = 0.0  # time.time() when _cache was stored
_lock = threading.Lock()  # guards _cache / _cache_at
_counters: dict[str, int] = {}  # round-robin cursor per pool, advanced by _pick


def _build_pools(models: list) -> tuple[dict, int, int]:
    """Build per-category pools from the OpenRouter models ``data`` list.

    Only ids ending in ``:free`` are kept. A model counts as vision-capable
    when "image" appears in ``architecture.modality`` or
    ``architecture.input_modalities``.

    Returns:
        ``(pools, n_free, n_vision)``; every pool entry is
        ``{"id": str, "score": int, "vision": bool}``.
    """
    text_all, vision_all = [], []
    for m in models:
        if not isinstance(m, dict):
            continue
        mid = m.get("id")
        if not isinstance(mid, str) or not mid.endswith(":free"):
            continue
        # Coerce null/missing fields so one odd entry can't abort the fetch.
        arch = m.get("architecture") or {}
        modality = arch.get("modality") or ""
        input_mods = arch.get("input_modalities") or []
        has_vision = ("image" in modality) or ("image" in input_mods)
        entry = {"id": mid, "score": _score(mid), "vision": has_vision}
        text_all.append(entry)
        if has_vision:
            vision_all.append(entry)
    # Sort best-first; the casual pool takes the lowest-scored tail instead.
    text_all.sort(key=lambda x: x["score"], reverse=True)
    vision_all.sort(key=lambda x: x["score"], reverse=True)
    pools = {
        "reasoning": text_all[:6],  # best 6
        "science": text_all[:8],  # best 8
        "general": text_all[:10],  # best 10
        # lowest 5, lowest score first; with fewer than 5 models, the top 3
        "casual": text_all[-5:][::-1] if len(text_all) >= 5 else text_all[:3],
        "vision": vision_all,  # all vision-capable
    }
    return pools, len(text_all), len(vision_all)


def _fallback_pools() -> dict:
    """Pools used when the model list can't be fetched and nothing is cached.

    Every text category gets DEFAULT_MODEL. The vision pool is left empty
    because DEFAULT_MODEL is not marked vision-capable: _pick then falls
    through to text routing and chat_node strips images, rather than sending
    image input to a text-only model.
    """
    entry = {"id": DEFAULT_MODEL, "score": 50, "vision": False}
    pools = {k: [entry] for k in ("reasoning", "science", "general", "casual")}
    pools["vision"] = []
    return pools


def _fetch_pools() -> dict:
    """Return model pools, refetching from OpenRouter at most every MODELS_TTL s.

    On any fetch or parse error the last successfully built pools are
    returned (even if stale), or _fallback_pools() if none exist yet.
    Failures are not cached, so the next call tries the network again.
    The fetch runs outside _lock, so concurrent callers that all see an
    expired cache may each fetch.
    """
    global _cache, _cache_at
    with _lock:
        if _cache and (time.time() - _cache_at) < MODELS_TTL:
            return _cache
    try:
        req = urllib.request.Request(
            "https://openrouter.ai/api/v1/models",
            headers={
                "HTTP-Referer": "https://stemcopilot.app",
                "User-Agent": "STEMCopilot/1.0",
            },
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            raw = json.loads(resp.read().decode())
        pools, n_free, n_vision = _build_pools(raw.get("data") or [])
    except Exception as exc:  # network / JSON / shape errors: degrade, don't crash
        print(f"[ROUTER] Fetch failed: {exc}")
        return _cache or _fallback_pools()
    with _lock:
        _cache = pools
        _cache_at = time.time()
    print(f"[ROUTER] {n_free} free models, {n_vision} vision-capable")
    for cat in ("reasoning", "casual", "vision"):
        ids = [e["id"] for e in pools[cat][:3]]
        if ids:
            print(f"[ROUTER] {cat}: {ids}")
    return pools
# ── Model Picker (round-robin) ────────────────────────────────
def _pick(category: str, has_image: bool = False) -> tuple[str, str]:
    """Choose the next model for *category*, rotating through its pool.

    Returns ``(model_id, actual_category)``. When ``has_image`` is set the
    vision pool is tried first; if it is empty, ordinary text routing is
    used and the caller must then strip images from the messages.
    """
    pools = _fetch_pools()

    def _rotate(key: str, pool: list) -> str:
        # Advance this key's round-robin cursor and return the chosen id.
        cursor = _counters.get(key, 0) % len(pool)
        _counters[key] = cursor + 1
        return pool[cursor]["id"]

    vision_pool = pools.get("vision", []) if has_image else []
    if vision_pool:
        return _rotate("vision", vision_pool), "vision"

    # Text routing: the category's own pool, else "general", else default.
    text_pool = pools.get(category) or pools.get("general") or []
    if text_pool:
        return _rotate(category, text_pool), category
    return DEFAULT_MODEL, category
# ── Message helpers ────────────────────────────────────────────
def _text_parts(content: list) -> list[str]:
    """Return the text pieces of a multimodal content list, in order.

    Content lists can mix plain strings and dict blocks; plain strings and
    ``{"type": "text"}`` blocks are kept, every other block type (e.g.
    ``"image_url"``) is dropped.
    """
    parts = []
    for p in content:
        if isinstance(p, str):
            parts.append(p)
        elif isinstance(p, dict) and p.get("type") == "text":
            parts.append(p.get("text") or "")
    return parts


def _extract_text(messages: list) -> str:
    """Get raw text from the last message (handles multimodal content)."""
    if not messages:
        return ""
    content = getattr(messages[-1], "content", "")
    if isinstance(content, list):
        return " ".join(_text_parts(content))
    return str(content)


def _strip_images(messages: list) -> list:
    """Return *messages* with list (multimodal) content collapsed to text.

    String-content messages are passed through as the same objects. A list
    content with no text becomes "(user sent an image)". Inputs are not
    mutated.
    """
    out = []
    for msg in messages:
        if not isinstance(msg.content, list):
            out.append(msg)
            continue
        text = " ".join(t for t in _text_parts(msg.content) if t).strip()
        if not text:
            text = "(user sent an image)"
        # model_copy keeps every other field (type, id, name, tool_calls,
        # tool_call_id, ...) that rebuilding via msg.__class__ would drop;
        # fall back to the constructor for objects without model_copy.
        copier = getattr(msg, "model_copy", None)
        if copier is not None:
            out.append(copier(update={"content": text}))
        else:
            out.append(msg.__class__(content=text))
    return out
# ── LLM factory ───────────────────────────────────────────────
def _make_llm(api_key: str, model_id: str):
    """Build a streaming ChatOpenRouter client for *model_id*.

    A non-empty per-user ``api_key`` takes precedence; otherwise the
    server-wide ``OPENROUTER_API_KEY`` from config is used.

    Raises:
        ValueError: if neither key is set.
    """
    resolved_key = api_key if api_key else OPENROUTER_API_KEY
    if resolved_key:
        settings = dict(
            temperature=0.5,
            max_tokens=4096,
            max_retries=3,
            streaming=True,
        )
        return ChatOpenRouter(
            model=model_id,
            openrouter_api_key=resolved_key,
            **settings,
        )
    raise ValueError(
        "No API key available. Please add your OpenRouter key in Settings."
    )
# ── LangGraph state & node ────────────────────────────────────
class ChatState(TypedDict):
    """Graph state: the running conversation."""

    # add_messages is the reducer for this key: messages returned by a node
    # are merged into the existing list rather than replacing it.
    messages: Annotated[list[BaseMessage], add_messages]
def chat_node(state: ChatState, config: RunnableConfig):
    """Route the newest user turn to a model and return its reply.

    Options are read from ``config["configurable"]``:
      - ``persona``, ``context``, ``language``, ``username``,
        ``student_profile``: passed to ``prompts.build`` for the system prompt.
      - ``user_api_key``: per-user OpenRouter key (falls back to the server key).
      - ``model``: explicit model id that bypasses routing.
      - ``has_image``: request a vision-capable model.

    Returns:
        ``{"messages": [reply]}``, merged into state by the add_messages reducer.
    """
    cfg = config.get("configurable", {})
    persona = cfg.get("persona", "nerd")
    context = cfg.get("context", "")
    language = cfg.get("language", "auto")
    username = cfg.get("username", "")
    profile = cfg.get("student_profile", "")
    api_key = cfg.get("user_api_key", "")
    override = cfg.get("model", "")
    has_image = cfg.get("has_image", False)
    # 1) Classify
    user_text = _extract_text(state["messages"])
    category = _classify(user_text)
    # 2) Pick model
    if override:
        model_id, actual = override, category
        # A user-chosen model keeps images only if the router's vision pool
        # lists it as vision-capable; otherwise images are stripped below.
        if has_image and any(
            entry.get("id") == override and entry.get("vision")
            for entry in (_fetch_pools().get("vision") or [])
        ):
            actual = "vision"
    else:
        model_id, actual = _pick(category, has_image=has_image)
    # 3) ALWAYS strip images from history when model is not vision-capable.
    # LangGraph checkpoint may contain old image messages from prior turns
    # that would cause "No endpoints found that support image input" on
    # text-only models.
    messages = state["messages"]
    if actual != "vision":
        messages = _strip_images(messages)
    # 4) Build system prompt & invoke (named system_msg to avoid shadowing
    # the stdlib `sys` module name).
    system_msg = SystemMessage(
        content=prompts.build(persona, context, language, username, profile)
    )
    print(f"[ROUTER] category={category} model={model_id} vision={actual == 'vision'}")
    llm = _make_llm(api_key, model_id)
    resp = llm.invoke([system_msg] + messages)
    return {"messages": [resp]}
# ── Compile graph ─────────────────────────────────────────────
# Single-node graph: START → chat_node → END.
_g = StateGraph(ChatState)
_g.add_node("chat_node", chat_node)
_g.add_edge(START, "chat_node")
_g.add_edge("chat_node", END)
# Compiling with the SQLite checkpointer lets LangGraph persist state between
# invocations (presumably keyed by a thread id the caller passes in config —
# confirm against the caller).
chatbot = _g.compile(checkpointer=checkpointer)