import json
import time

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

from .config import Settings


def create_llm(settings: Settings) -> ChatOpenAI:
    """Build the shared ChatOpenAI client from application settings."""
    return ChatOpenAI(
        model=settings.openai_model,
        api_key=settings.openai_api_key,
        base_url=settings.openai_base_url,
        temperature=settings.openai_temperature,
        max_tokens=800,
    )
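
# Usage sketch (illustrative only; assumes Settings can be constructed from
# environment variables -- adjust to however .config actually defines it):
#
#     llm = create_llm(Settings())
#     print(llm.invoke("One-line market summary, please.").content)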


def _describe_step(node_name: str, state_update: dict) -> str | None:
    """Turn one node's state update into a human-readable step summary.

    Returns None when the update carries nothing reportable (e.g. a worker
    node that emitted no messages), so callers can skip it.
    """
    if node_name == "Planner":
        tasks = state_update.get("pending_tasks", [])
        return f"Generated {len(tasks)} parallel task(s): {[t['task_id'] for t in tasks]}"
    if node_name == "Supervisor":
        next_agents = state_update.get("next", [])
        if next_agents == "FINISH":
            return "All tasks complete. Routing to Summarizer."
        return f"Dispatching tasks to: {next_agents}"
    # Worker and Summarizer nodes report via their latest message.
    messages = state_update.get("messages", [])
    if not messages:
        return None
    return messages[-1].content


def run_financial_query(compiled_graph, user_query: str) -> dict:
    """Run one LangGraph turn synchronously.

    Returns the final memo (None unless the Summarizer ran), the per-node
    step log, and the total wall-clock latency in seconds.
    """
    initial_state = {
        "messages": [HumanMessage(content=user_query)],
        "steps": 0,
        "completed_tasks": set(),
        "pending_tasks": [],
    }
    # Truncate long queries so the trace run name stays readable.
    run_label = user_query if len(user_query) <= 80 else user_query[:77] + "..."
    stream_config = {
        "run_name": run_label,
        "tags": ["fin-agent", "langgraph"],
        "metadata": {"app": "FinAgent"},
    }
    steps: list[dict] = []
    memo: str | None = None

    start_time = time.time()
    last_time = start_time
    total_latency = 0.0

    # Each streamed item maps a node name to that node's state update.
    for output in compiled_graph.stream(initial_state, stream_config):
        for node_name, state_update in output.items():
            current_time = time.time()
            step_latency = current_time - last_time
            total_latency = current_time - start_time
            last_time = current_time

            content = _describe_step(node_name, state_update)
            if content is None:
                continue

            # The Summarizer's output is the memo; everything else is a step.
            if node_name == "Summarizer":
                memo = content
            else:
                steps.append({
                    "node": node_name,
                    "content": content,
                    "step_latency": round(step_latency, 2),
                    "total_latency": round(total_latency, 2),
                })
    return {"memo": memo, "steps": steps, "total_latency": round(total_latency, 2)}


async def astream_financial_query(compiled_graph, user_query: str):
    """Async generator yielding one Server-Sent Event (SSE) per graph step.

    Useful for streaming progress over HTTP as text/event-stream.
    """
    initial_state = {
        "messages": [HumanMessage(content=user_query)],
        "steps": 0,
        "completed_tasks": set(),
        "pending_tasks": [],
    }
    # Truncate long queries so the trace run name stays readable.
    run_label = user_query if len(user_query) <= 80 else user_query[:77] + "..."
    stream_config = {
        "run_name": run_label,
        "tags": ["fin-agent", "langgraph"],
        "metadata": {"app": "FinAgent"},
    }

    start_time = time.time()
    last_time = start_time

    async for output in compiled_graph.astream(initial_state, stream_config):
        for node_name, state_update in output.items():
            current_time = time.time()
            step_latency = current_time - last_time
            total_latency = current_time - start_time
            last_time = current_time

            content = _describe_step(node_name, state_update)
            if content is None:
                continue

            data = {
                "node": node_name,
                "content": content,
                "step_latency": round(step_latency, 2),
                "total_latency": round(total_latency, 2),
            }
            # SSE framing: a "data:" field terminated by a blank line.
            yield f"data: {json.dumps(data)}\n\n"