# Page header captured along with this file ("Spaces: Running") - not part of the module.
| import sys | |
| import os | |
| import pickle | |
| import sqlite3 | |
| import pandas as pd | |
| from fastapi import FastAPI | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from dotenv import load_dotenv | |
| from google import genai | |
| import chromadb | |
| from typing import List, Dict | |
# Read settings from the .env file that sits beside this module.
_module_dir = os.path.dirname(__file__)
env_path = os.path.join(_module_dir, '.env')
load_dotenv(env_path)

# Put the parent directory on sys.path so the sibling `middleware` package
# can be imported below.
_parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(_parent_dir)
from middleware.material_predictor import predict_material_needs  # noqa: E402
app = FastAPI(title="Wafer Defect API")

# CORS policy: every origin, method and header is accepted.
# Credentialed CORS stays off: the FastAPI CORS documentation states that
# allow_origins / allow_methods / allow_headers cannot be ["*"] when
# allow_credentials is True. To support cookies or auth headers across
# origins, list the allowed origins explicitly and re-enable it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Filesystem layout, anchored on this module's own directory so the paths do
# not depend on the working directory the server is started from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_MIDDLEWARE_DIR = os.path.join(BASE_DIR, '..', 'middleware')
DB_PATH = os.path.join(_MIDDLEWARE_DIR, 'wafer_control.db')
MODEL_PATH = os.path.join(_MIDDLEWARE_DIR, 'material_model.pkl')
CHROMA_PATH = os.path.join(BASE_DIR, 'chroma_db')

# Create the ChromaDB directory up front when it is missing.
os.makedirs(CHROMA_PATH, exist_ok=True)
# Hex colour assigned to each defect label (one entry per label, in the
# order the labels were originally listed).
DEFECT_COLORS = {
    'Center': '#ef4444',
    'Donut': '#f59e0b',
    'Edge-Loc': '#10b981',
    'Edge-Ring': '#3b82f6',
    'Loc': '#8b5cf6',
    'Random': '#ec4899',
    'Scratch': '#06b6d4',
    'Near-full': '#f97316',
    'None': '#6b7280',
    'Undetected': '#374151',
}
# Load the whole wafer_logs table once at import time so request handlers
# only ever read from memory. When the database file is missing, `df` stays
# an empty, column-less DataFrame.
df = pd.DataFrame()
if os.path.exists(DB_PATH):
    print(f"Loading DB from {DB_PATH}...")
    conn = sqlite3.connect(DB_PATH)
    try:
        df = pd.read_sql_query("SELECT * FROM wafer_logs", conn)
    finally:
        # Release the SQLite handle even when the query raises.
        conn.close()
    # Parse timestamps and derive the calendar date used for daily grouping.
    df['scan_time'] = pd.to_datetime(df['scan_time'])
    df['scan_date'] = df['scan_time'].dt.date
else:
    print(f"Warning: DB not found at {DB_PATH}. Dashboard will be empty.")
# Setup Vector DB and LLM
print(f"Connecting to ChromaDB at {CHROMA_PATH}...")
# Define both names before the attempt so they always exist at module level,
# even when the client constructor itself raises.
chroma_client = None
collection = None
try:
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = chroma_client.get_or_create_collection(name="semiconductor_knowledge")
except Exception as e:
    # Best-effort startup: the vector store is optional, so report the
    # problem and carry on with `collection` unset.
    print(f"Warning: Could not connect to ChromaDB collection. Error: {e}")
    collection = None
print("Initializing Gemini API...")
# The client is only built when GEMINI_API_KEY is set to a non-empty value;
# otherwise `gemini_client` stays None.
_gemini_api_key = os.getenv("GEMINI_API_KEY")
if _gemini_api_key:
    gemini_client = genai.Client(api_key=_gemini_api_key)
else:
    gemini_client = None
    print("Warning: GEMINI_API_KEY not found in environment.")
print("Loading ML model...")
# NOTE(review): pickle.load runs arbitrary code embedded in the file, so
# MODEL_PATH must only ever point at a trusted artifact.
if os.path.exists(MODEL_PATH):
    with open(MODEL_PATH, 'rb') as model_file:
        model_pkg = pickle.load(model_file)
else:
    # No model file on disk: `model_pkg` stays None.
    model_pkg = None
def _rounded_mean(series, ndigits):
    """Return the mean of *series* rounded to *ndigits*, or 0 when it is undefined."""
    mean_value = series.mean()
    # An empty or all-null series yields NaN, which cannot be JSON-encoded.
    if pd.isna(mean_value):
        return 0
    return round(float(mean_value), ndigits)


def get_kpis():
    """Summarise the loaded wafer log (`df`) into headline KPIs.

    Returns:
        A dict with the keys:
        - total_scans: number of rows in `df`
        - pass_count / fail_count: rows whose status is 'PASS' / 'FAIL'
        - pass_rate: pass_count as a percentage of total_scans (1 decimal)
        - fail_rate: 100 - pass_rate (1 decimal)
        - scrap_count: rows whose action is 'ROUTE_TO_SCRAP'
        - avg_waste / avg_confidence: mean of material_wasted_pct /
          confidence over the FAIL rows (2 decimals), 0 when undefined

        Every value is 0 when no data is loaded.
    """
    if df.empty:
        # No database was loaded (or it holds no rows): `df` may have no
        # columns at all, so report a zeroed dashboard instead of failing on
        # a missing column.
        return {
            "total_scans": 0,
            "pass_count": 0,
            "pass_rate": 0,
            "fail_count": 0,
            "fail_rate": 0,
            "scrap_count": 0,
            "avg_waste": 0,
            "avg_confidence": 0
        }

    total_scans = len(df)
    failed = df[df['status'] == 'FAIL']
    fail_count = len(failed)
    pass_count = int((df['status'] == 'PASS').sum())
    pass_rate = round((pass_count / total_scans) * 100, 1)
    scrap_count = int((df['action'] == 'ROUTE_TO_SCRAP').sum())

    return {
        "total_scans": total_scans,
        "pass_count": pass_count,
        "pass_rate": pass_rate,
        "fail_count": fail_count,
        "fail_rate": round(100 - pass_rate, 1),
        "scrap_count": scrap_count,
        "avg_waste": _rounded_mean(failed['material_wasted_pct'], 2),
        "avg_confidence": _rounded_mean(failed['confidence'], 2)
    }
def get_defects():
    """Count FAIL rows per predicted defect type and per ground-truth label.

    Returns:
        A dict with two lists of records:
        - predictions: {'defect_type', 'count'} for every predicted type
        - ground_truth: {'ground_truth', 'count'} for the 15 most frequent
          ground-truth labels

        Both lists are ordered by descending count and are empty when no
        data is loaded.
    """
    if df.empty:
        # A missing database leaves `df` without columns; return empty
        # chart data instead of failing on the 'status' lookup.
        return {"predictions": [], "ground_truth": []}

    failed = df[df['status'] == 'FAIL']

    # rename_axis/reset_index(name=...) pins both column names explicitly.
    predicted = (
        failed['defect_type']
        .value_counts()
        .rename_axis('defect_type')
        .reset_index(name='count')
    )
    actual = (
        failed['ground_truth']
        .value_counts()
        .rename_axis('ground_truth')
        .reset_index(name='count')
    )

    return {
        "predictions": predicted.to_dict(orient="records"),
        "ground_truth": actual.head(15).to_dict(orient="records")
    }
def get_waste():
    """Aggregate wasted material per defect type and count routing actions.

    Returns:
        A dict with two lists of records:
        - waste_by_type: {'defect_type', 'total_waste'} over FAIL rows, where
          total_waste is the summed material_wasted_pct divided by 100,
          sorted ascending by total_waste
        - actions: {'action', 'count'} over all rows

        Both lists are empty when no data is loaded.
    """
    if df.empty:
        # A missing database leaves `df` without columns; return empty
        # chart data instead of failing on the 'status' lookup.
        return {"waste_by_type": [], "actions": []}

    failed = df[df['status'] == 'FAIL']
    waste_by_type = (
        failed.groupby('defect_type')['material_wasted_pct']
        .sum()
        .div(100.0)  # summed percentage points -> divided by 100
        .rename('total_waste')
        .reset_index()
        .sort_values('total_waste', ascending=True)
    )

    action_counts = (
        df['action']
        .value_counts()
        .rename_axis('action')
        .reset_index(name='count')
    )

    return {
        "waste_by_type": waste_by_type.to_dict(orient="records"),
        "actions": action_counts.to_dict(orient="records")
    }
def get_trends():
    """Build per-day fail-rate and waste series from the wafer log.

    Rows are grouped by `scan_date`. For each day the function counts `id`
    values (scans), counts rows whose status is 'FAIL', and sums
    material_wasted_pct divided by 100.

    Returns:
        A dict of three parallel lists:
        - dates: the scan dates as strings
        - fail_rate: fails / scans as a percentage (1 decimal)
        - waste: the per-day waste total

        All three lists are empty when no data is loaded.
    """
    if df.empty:
        # A missing database leaves `df` without a 'scan_date' column;
        # return empty series instead of failing in groupby.
        return {"dates": [], "fail_rate": [], "waste": []}

    daily = df.groupby('scan_date').agg(
        scans=('id', 'count'),
        fails=('status', lambda statuses: (statuses == 'FAIL').sum()),
        waste=('material_wasted_pct', lambda pct: pct.sum() / 100.0),
    ).reset_index()
    daily['fail_rate'] = (daily['fails'] / daily['scans'] * 100).round(1)

    return {
        "dates": daily['scan_date'].astype(str).tolist(),
        "fail_rate": daily['fail_rate'].tolist(),
        "waste": daily['waste'].tolist()
    }
def model_status():
    """Report whether the pickled model package is loaded, plus its metrics.

    Returns:
        {"loaded": False} when `model_pkg` is unset or empty. Otherwise a
        dict with "loaded": True, the stored r2 (4 decimals) and mae
        (2 decimals), and up to ten {'feature', 'importance'} records: the
        largest importances, listed in ascending order.
    """
    if not model_pkg:
        return {"loaded": False}

    metrics = model_pkg['metrics']
    importances = metrics['importances']

    ranking = pd.DataFrame({
        'feature': list(importances.keys()),
        'importance': list(importances.values()),
    })
    # Ascending sort + tail keeps the ten most important features, with the
    # most important one last.
    top_features = ranking.sort_values('importance', ascending=True).tail(10)

    rounded_metrics = {"r2": round(metrics['r2'], 4), "mae": round(metrics['mae'], 2)}
    return {
        "loaded": True,
        "metrics": rounded_metrics,
        "importance": top_features.to_dict(orient="records")
    }
class PredictionRequest(BaseModel):
    """Request body consumed by `predict_waste`."""

    # Scan count handed to the predictor unchanged.
    scans: int
    # Fail rate as a percentage; `predict_waste` divides it by 100 before
    # calling the predictor.
    fail_rate: float
def predict_waste(req: PredictionRequest):
    """Run the material-needs predictor for a hypothetical scan volume.

    The defect-type mix of the currently loaded FAIL rows (as normalised
    frequencies) is passed to `predict_material_needs` together with the
    requested scan count and the fail rate converted from percent to a
    fraction.

    Args:
        req: scan count and fail rate (in percent).

    Returns:
        The predictor's result with the requested `fail_rate` (in percent)
        written into it, or {"error": ...} when the model or the wafer data
        is not loaded.
    """
    if not model_pkg:
        return {"error": "No model loaded"}
    if df.empty:
        # Without wafer data there is no defect distribution to feed the
        # model (and `df` may not even have a 'status' column).
        return {"error": "No wafer data loaded"}

    failed = df[df['status'] == 'FAIL']
    defect_mix = failed['defect_type'].value_counts(normalize=True).to_dict()

    pred = predict_material_needs(
        model_pkg['model'],
        model_pkg['feature_cols'],
        req.scans,
        req.fail_rate / 100.0,  # percent -> fraction
        defect_mix,
    )
    # Echo the caller's percentage back alongside the prediction.
    pred['fail_rate'] = req.fail_rate
    return pred
class ChatMessage(BaseModel):
    """One turn of a chat conversation."""

    # Speaker tag; `chat_with_bot` forwards "user" as-is and maps every
    # other value to "model".
    role: str
    # Text of the message.
    content: str
class ChatRequest(BaseModel):
    """Request body consumed by `chat_with_bot`."""

    # Conversation history; `chat_with_bot` treats the last entry as the
    # current user query.
    messages: List[ChatMessage]
def _retrieve_context(user_message):
    """Return up to two knowledge-base passages for *user_message*, newline-joined.

    Returns "" when there is no collection, no message, no hit, or the query
    fails (retrieval is best-effort).
    """
    if not (collection and user_message):
        return ""
    try:
        results = collection.query(query_texts=[user_message], n_results=2)
        if results and results['documents'] and results['documents'][0]:
            return "\n".join(results['documents'][0])
    except Exception as e:
        # A retrieval failure should not block the chat; log and continue
        # without engineering context.
        print(f"ChromaDB Query Error: {e}")
    return ""


def _dashboard_snapshot():
    """Render the current headline numbers of `df` as prompt text."""
    total_scans = len(df)
    if df.empty:
        # No database loaded: `df` may have no columns, so avoid indexing it.
        fail_count = 0
        pass_rate = 0
        top_defects = {}
    else:
        failed = df[df['status'] == 'FAIL']
        fail_count = len(failed)
        pass_rate = round(((total_scans - fail_count) / total_scans) * 100, 1)
        top_defects = failed['defect_type'].value_counts().head(3).to_dict()

    # The prompt text is kept flush-left so no indentation leaks into it.
    return f"""
Current Dashboard State:
- Total Wafers Scanned: {total_scans}
- Current Pass Rate: {pass_rate}%
- Total Defective Wafers: {fail_count}
- Top Defect Types Right Now: {top_defects}
"""


def chat_with_bot(req: ChatRequest):
    """Answer a chat conversation with Gemini, grounded in live data and docs.

    The last message is used as the retrieval query against the ChromaDB
    collection. The retrieved passages and a snapshot of the dashboard
    numbers are placed in the system instruction, and the whole conversation
    is sent to the model.

    Args:
        req: the conversation so far; the last entry is the current query.

    Returns:
        {"response": <model text>} on success, or {"error": <message>} when
        the Gemini client is not configured, no messages were supplied, or
        the API call fails.
    """
    if not gemini_client:
        return {"error": "Gemini API key not configured"}
    if not req.messages:
        # Nothing to answer; skip the API call instead of sending an empty
        # conversation.
        return {"error": "No messages provided"}

    user_message = req.messages[-1].content

    # 1. RAG retrieval from ChromaDB
    context_docs = _retrieve_context(user_message)

    # 2. Live dashboard context
    live_kpis = _dashboard_snapshot()

    # 3. System prompt
    docs_section = context_docs if context_docs else "No specific engineering docs retrieved."
    system_instruction = f"""
You are the 'Gorilla Semiconductors Engineering Assistant', an expert semiconductor manufacturing assistant.
You help engineers understand dashboard data and troubleshoot wafer defects.
Maintain a strictly professional, analytical, and authoritative engineering tone.
Here is the LIVE DATA from the dashboard:
{live_kpis}
Here is retrieved technical context from our engineering database based on the user's query:
{docs_section}
Use the live data to answer questions about 'current status' or 'dashboard'.
Use the engineering docs to answer questions about 'why' a defect happens.
"""

    try:
        # Convert messages to the format expected by google-genai: "user"
        # stays "user", every other role is sent as "model".
        contents = [
            genai.types.Content(
                role="user" if msg.role == "user" else "model",
                parts=[genai.types.Part.from_text(text=msg.content)],
            )
            for msg in req.messages
        ]
        response = gemini_client.models.generate_content(
            model='gemini-2.5-flash-lite',
            contents=contents,
            config=genai.types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=0.3
            )
        )
        return {"response": response.text}
    except Exception as e:
        # Boundary with an external service: report the failure to the
        # caller rather than raising.
        print(f"Gemini API Error: {e}")
        return {"error": str(e)}
# --- SERVE FRONTEND ---
FRONTEND_PATH = os.path.join(BASE_DIR, "..", "frontend", "dist")


def _resolve_frontend_file(requested_path):
    """Map a request path to an existing file inside FRONTEND_PATH.

    Returns the resolved file path, or None when the file does not exist or
    the path would escape the frontend directory (via '..' segments, an
    absolute path, or a symlink pointing outside).
    """
    try:
        root = os.path.realpath(FRONTEND_PATH)
        candidate = os.path.realpath(os.path.join(root, requested_path))
        if os.path.commonpath([root, candidate]) != root:
            return None
    except ValueError:
        # Raised for embedded NUL bytes or paths on different drives.
        return None
    return candidate if os.path.isfile(candidate) else None


# NOTE(review): no route decorator is visible in this view. The `full_path`
# parameter suggests a catch-all path route - confirm how these handlers are
# registered on `app`.
if os.path.exists(FRONTEND_PATH):
    async def serve_frontend(full_path: str):
        """Serve a built frontend asset, falling back to index.html.

        Args:
            full_path: the request path, treated as untrusted input.

        Returns:
            A FileResponse for the matching asset, or for index.html when no
            asset matches (so client-side routes load the app).

        Raises:
            HTTPException: 404 for paths starting with "api" and when
                index.html is missing from the build.
        """
        from fastapi import HTTPException

        # 1. Skip API routes
        if full_path.startswith("api"):
            raise HTTPException(status_code=404, detail="Not Found")

        # 2. Serve the file if it exists (assets like .js, .css, .png),
        #    but only from inside the build directory.
        file_path = _resolve_frontend_file(full_path)
        if file_path is not None:
            return FileResponse(file_path)

        # 3. Fallback to index.html for React Router
        index_file = os.path.join(FRONTEND_PATH, "index.html")
        if os.path.exists(index_file):
            return FileResponse(index_file)
        raise HTTPException(status_code=404, detail="Frontend build not found")
else:
    def read_root():
        """Report that the API is up but no frontend build is available."""
        return {"message": "Wafer Defect API is running. Frontend build folder not found."}