# src/llm.py
# Groq LLM call with tenacity retry
# Single call per inference - not a multi-step chain
# Non-blocking: queued as FastAPI BackgroundTask, polled via /report/{id}
import os
import uuid

import httpx
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type
)

GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
GROQ_MODEL = "llama-3.3-70b-versatile"
MAX_TOKENS = 512

# In-memory report store: report_id -> {status, report}
# FastAPI polls this via GET /report/{report_id}
_report_store: dict = {}
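# NOTE: a module-level dict is per-process state. This works for a
# single-worker deployment; with multiple uvicorn workers, each process
# would hold its own (disjoint) store.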
class LLMAPIError(Exception):
    pass


def _build_prompt(category: str,
                  anomaly_score: float,
                  similar_cases: list,
                  graph_context: dict) -> list:
    """
    Build LLM messages list.
    Strictly grounded: the model must cite case IDs and cannot use
    outside knowledge. One call per inference.
    Context = retrieved cases + graph context.
    """
    system = (
        "You are an industrial quality control assistant. "
        "Answer ONLY based on the retrieved cases and graph context provided. "
        "Do not use outside knowledge. "
        "Always cite the Case ID when referencing a case. "
        "Be concise: 3 to 5 sentences maximum."
    )
    # Build context block from retrieved similar cases
    context_lines = []
    for i, case in enumerate(similar_cases[:5]):
        context_lines.append(
            f"[Case {i+1}: category={case.get('category')}, "
            f"defect={case.get('defect_type')}, "
            f"similarity={case.get('similarity_score', 0):.3f}]"
        )
    # Add graph context
    root_causes = graph_context.get("root_causes", [])
    remediations = graph_context.get("remediations", [])
    if root_causes:
        context_lines.append(f"Root causes: {', '.join(root_causes)}")
    if remediations:
        context_lines.append(f"Remediations: {', '.join(remediations)}")
    context_str = "\n".join(context_lines) if context_lines else "No context available."
    user_msg = (
        f"CONTEXT:\n{context_str}\n\n"
        f"QUERY: Image anomaly score {anomaly_score:.3f}. "
        f"Category: {category}. "
        f"Describe the likely defect, root cause, and recommended action."
        f"\n\nREPORT:"
    )
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user_msg}
    ]
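# Illustrative example (made-up field values) of what this builds:
#   _build_prompt("bottle", 0.912,
#                 [{"category": "bottle", "defect_type": "crack",
#                   "similarity_score": 0.87}],
#                 {"root_causes": ["mold misalignment"],
#                  "remediations": ["recalibrate mold"]})
# returns the system message above plus a user message whose CONTEXT is:
#   [Case 1: category=bottle, defect=crack, similarity=0.870]
#   Root causes: mold misalignment
#   Remediations: recalibrate mold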
@retry(
    retry=retry_if_exception_type(LLMAPIError),
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=2, min=2, max=8),
    reraise=True,
)
def _call_groq(messages: list) -> str:
    """
    Single Groq API call with tenacity retry.
    Retries up to 3 times with 2s/4s/8s backoff on failure.
    Re-raises LLMAPIError to the caller if the final attempt fails.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise LLMAPIError("GROQ_API_KEY not set in environment")
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                GROQ_API_URL,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": GROQ_MODEL,
                    "messages": messages,
                    "max_tokens": MAX_TOKENS,
                    "temperature": 0.3  # low temp = factual, grounded
                }
            )
        if response.status_code == 429:
            raise LLMAPIError("Groq rate limit hit")
        if response.status_code != 200:
            raise LLMAPIError(f"Groq API error {response.status_code}: "
                              f"{response.text[:200]}")
        data = response.json()
        content = data["choices"][0]["message"]["content"].strip()
        if not content:
            raise LLMAPIError("Groq returned empty response")
        return content
    except httpx.TimeoutException:
        raise LLMAPIError("Groq API timeout")
    except httpx.RequestError as e:
        raise LLMAPIError(f"Groq request failed: {e}")
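# Failure path (illustrative): a 429 or timeout raises LLMAPIError, tenacity
# sleeps and retries; if the last attempt also fails, the error propagates to
# generate_report below, which stores a fallback report instead.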
def queue_report(category: str,
                 anomaly_score: float,
                 similar_cases: list,
                 graph_context: dict) -> str:
    """
    Queue an LLM report generation.
    Returns report_id immediately; the report is generated asynchronously.
    Frontend polls GET /report/{report_id} every 500ms.
    """
    report_id = str(uuid.uuid4())
    _report_store[report_id] = {"status": "pending", "report": None}
    return report_id
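# Design note: queue_report only registers the pending entry; the route is
# expected to schedule generate_report with the same arguments as a
# BackgroundTask (see the wiring sketch at the bottom of this file).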
def generate_report(report_id: str,
                    category: str,
                    anomaly_score: float,
                    similar_cases: list,
                    graph_context: dict):
    """
    Called as a FastAPI BackgroundTask.
    Generates the report and stores it in _report_store under report_id.
    """
    try:
        messages = _build_prompt(category, anomaly_score,
                                 similar_cases, graph_context)
        report = _call_groq(messages)
        _report_store[report_id] = {"status": "ready", "report": report}
    except LLMAPIError as e:
        # Degrade gracefully: mark the report "ready" so polling terminates,
        # but serve fallback text instead of a generated report.
        fallback = (
            "LLM temporarily unavailable. "
            "Retrieved cases and graph context are shown above. "
            f"(Error: {str(e)[:100]})"
        )
        _report_store[report_id] = {"status": "ready", "report": fallback}
    except Exception:
        _report_store[report_id] = {
            "status": "ready",
            "report": "Could not generate report. Please retry."
        }
def get_report(report_id: str) -> dict:
    """
    Poll report status.
    Returns: {status: pending} or {status: ready, report: "..."}
    """
    return _report_store.get(
        report_id,
        {"status": "not_found", "report": None}
    )
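# Poll lifecycle as seen by the frontend (report text is illustrative):
#   GET /report/<id>      -> {"status": "pending", "report": None}
#   GET /report/<id>      -> {"status": "ready", "report": "Case 2 shows ..."}
#   GET /report/<unknown> -> {"status": "not_found", "report": None}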
def cleanup_old_reports(max_age_seconds: int = 3600):
    """Prevent _report_store growing unbounded. Called periodically."""
    # Simple approach: once the store exceeds 500 entries, drop the oldest
    # 250. Dicts preserve insertion order (Python 3.7+), so the first keys
    # are the oldest. max_age_seconds is accepted for a future age-based
    # policy; the current size-cap implementation does not use it.
    if len(_report_store) > 500:
        keys = list(_report_store.keys())
        for key in keys[:250]:
            del _report_store[key]
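# --- Sketch: hypothetical FastAPI wiring (e.g. src/app.py; names assumed) ---
# Shows how queue_report / generate_report / get_report fit together. Only
# GET /report/{report_id} is documented above; the POST /analyze route and
# its payload shape are illustrative assumptions, not part of this module.
#
# from fastapi import BackgroundTasks, FastAPI
# from src.llm import generate_report, get_report, queue_report
#
# app = FastAPI()
#
# @app.post("/analyze")
# async def analyze(payload: dict, background_tasks: BackgroundTasks):
#     report_id = queue_report(payload["category"], payload["anomaly_score"],
#                              payload["similar_cases"], payload["graph_context"])
#     background_tasks.add_task(generate_report, report_id,
#                               payload["category"], payload["anomaly_score"],
#                               payload["similar_cases"], payload["graph_context"])
#     return {"report_id": report_id}
#
# @app.get("/report/{report_id}")
# async def report_status(report_id: str):
#     return get_report(report_id)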