File size: 4,634 Bytes
c6d67ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 | from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
import json
import time
from .config import Settings
def create_llm(settings: Settings) -> ChatOpenAI:
    """Build the chat-model client from the application settings.

    The model name, API key, base URL and temperature are read from
    ``settings``; the completion length is capped at 800 tokens.
    """
    llm_kwargs = {
        "model": settings.openai_model,
        "api_key": settings.openai_api_key,
        "base_url": settings.openai_base_url,
        "temperature": settings.openai_temperature,
        # CRITICAL FIX: Stops Groq from 'reserving' 8000+ tokens per API call
        "max_tokens": 800,
    }
    return ChatOpenAI(**llm_kwargs)
def run_financial_query(compiled_graph, user_query: str) -> dict:
    """
    Run one LangGraph turn. Returns memo (if Summarizer ran) and per-node step contents.

    Args:
        compiled_graph: Object exposing ``stream(state, config)`` that yields
            ``{node_name: state_update}`` mappings.
        user_query: The user's question. It seeds the message history and,
            truncated to 80 characters, is used as the run name.

    Returns:
        A dict with three keys:
        ``memo`` - content of the last "Summarizer" message, or ``None`` if
        that node never reported one;
        ``steps`` - one dict per other reporting node with ``node``,
        ``content``, ``step_latency`` and ``total_latency`` (seconds, rounded
        to 2 places);
        ``total_latency`` - seconds from start to the last streamed update.
    """
    initial_state = {
        "messages": [HumanMessage(content=user_query)],
        "steps": 0,
        "completed_tasks": set(),
        "pending_tasks": [],
    }
    run_label = user_query if len(user_query) <= 80 else user_query[:77] + "..."
    stream_config = {
        "run_name": run_label,
        "tags": ["fin-agent", "langgraph"],
        "metadata": {"app": "FinAgent"},
    }

    steps: list[dict] = []
    memo: str | None = None

    # perf_counter() is monotonic, so measured latencies cannot go negative
    # or jump when the system clock is adjusted mid-run.
    start_time = time.perf_counter()
    last_time = start_time
    total_latency = 0.0

    for output in compiled_graph.stream(initial_state, stream_config):
        for node_name, state_update in output.items():
            current_time = time.perf_counter()
            step_latency = current_time - last_time
            total_latency = current_time - start_time
            last_time = current_time

            # The stream may carry entries whose payload is not a dict (e.g. a
            # node that produced no update); there is nothing to report for
            # those, and calling .get() on them would raise.
            if not isinstance(state_update, dict):
                continue

            if node_name == "Planner":
                # "or []" also covers an explicit pending_tasks=None.
                tasks = state_update.get("pending_tasks") or []
                content = f"Generated {len(tasks)} parallel task(s): {[t['task_id'] for t in tasks]}"
            elif node_name == "Supervisor":
                next_agents = state_update.get("next", [])
                if next_agents == "FINISH":
                    content = "All tasks complete. Routing to Summarizer."
                else:
                    content = f"Dispatching tasks to: {next_agents}"
            else:
                messages = state_update.get("messages", [])
                if not messages:
                    continue
                content = messages[-1].content

            if node_name == "Summarizer":
                # The summarizer's output is the final memo, not a trace step.
                memo = content
            else:
                steps.append({
                    "node": node_name,
                    "content": content,
                    "step_latency": round(step_latency, 2),
                    "total_latency": round(total_latency, 2),
                })

    return {"memo": memo, "steps": steps, "total_latency": round(total_latency, 2)}
async def astream_financial_query(compiled_graph, user_query: str):
    """
    Async generator yielding Server-Sent Events (SSE) for each graph step.
    Useful for streaming over HTTP.

    Args:
        compiled_graph: Object exposing ``astream(state, config)`` that
            asynchronously yields ``{node_name: state_update}`` mappings.
        user_query: The user's question. It seeds the message history and,
            truncated to 80 characters, is used as the run name.

    Yields:
        One ``"data: <json>\\n\\n"`` string per reporting node, where the JSON
        object has ``node``, ``content``, ``step_latency`` and
        ``total_latency`` (seconds, rounded to 2 places).
    """
    initial_state = {
        "messages": [HumanMessage(content=user_query)],
        "steps": 0,
        "completed_tasks": set(),
        "pending_tasks": [],
    }
    run_label = user_query if len(user_query) <= 80 else user_query[:77] + "..."
    stream_config = {
        "run_name": run_label,
        "tags": ["fin-agent", "langgraph"],
        "metadata": {"app": "FinAgent"},
    }

    # perf_counter() is monotonic, so measured latencies cannot go negative
    # or jump when the system clock is adjusted mid-run.
    start_time = time.perf_counter()
    last_time = start_time

    async for output in compiled_graph.astream(initial_state, stream_config):
        for node_name, state_update in output.items():
            current_time = time.perf_counter()
            step_latency = current_time - last_time
            total_latency = current_time - start_time
            last_time = current_time

            # The stream may carry entries whose payload is not a dict (e.g. a
            # node that produced no update); there is nothing to report for
            # those, and calling .get() on them would raise.
            if not isinstance(state_update, dict):
                continue

            if node_name == "Planner":
                # "or []" also covers an explicit pending_tasks=None.
                tasks = state_update.get("pending_tasks") or []
                content = f"Generated {len(tasks)} parallel task(s): {[t['task_id'] for t in tasks]}"
            elif node_name == "Supervisor":
                next_agents = state_update.get("next", [])
                if next_agents == "FINISH":
                    content = "All tasks complete. Routing to Summarizer."
                else:
                    content = f"Dispatching tasks to: {next_agents}"
            else:
                messages = state_update.get("messages", [])
                if not messages:
                    continue
                content = messages[-1].content

            data = {
                "node": node_name,
                "content": content,
                "step_latency": round(step_latency, 2),
                "total_latency": round(total_latency, 2),
            }
            # SSE framing: a "data:" field terminated by a blank line.
            yield f"data: {json.dumps(data)}\n\n"
|