from __future__ import annotations import asyncio import json import os import shutil import sys import uuid from contextlib import asynccontextmanager from io import BytesIO from pathlib import Path from queue import Empty, Queue from threading import Thread from typing import Optional import uvicorn from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile from fastapi.concurrency import run_in_threadpool from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from PIL import Image try: from .model_utils import ( DEFAULT_DO_SAMPLE, DEFAULT_MAX_NEW_TOKENS, DEFAULT_MODEL_PATH, DEFAULT_REPETITION_PENALTY, QuantizedSkinGPTModel, ) except ImportError: from model_utils import ( DEFAULT_DO_SAMPLE, DEFAULT_MAX_NEW_TOKENS, DEFAULT_MODEL_PATH, DEFAULT_REPETITION_PENALTY, QuantizedSkinGPTModel, ) try: from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service except ImportError: sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service TEMP_DIR = Path(__file__).resolve().parents[1] / "temp_uploads" TEMP_DIR.mkdir(parents=True, exist_ok=True) DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY") deepseek_service: Optional[DeepSeekService] = None def parse_diagnosis_result(raw_text: str) -> dict: import re think_match = re.search(r"([\s\S]*?)", raw_text) answer_match = re.search(r"([\s\S]*?)", raw_text) thinking = think_match.group(1).strip() if think_match else None answer = answer_match.group(1).strip() if answer_match else None if not thinking: unclosed_think = re.search(r"([\s\S]*?)(?=|$)", raw_text) if unclosed_think: thinking = unclosed_think.group(1).strip() if not answer: unclosed_answer = re.search(r"([\s\S]*?)$", raw_text) if unclosed_answer: answer = unclosed_answer.group(1).strip() if not answer: cleaned = re.sub(r"[\s\S]*?", "", raw_text) cleaned = re.sub(r"[\s\S]*", "", cleaned) cleaned = re.sub(r"", "", cleaned) answer = cleaned.strip() or raw_text if answer: answer = re.sub(r"|", "", answer).strip() final_answer_match = re.search(r"Final Answer:\s*([\s\S]*)", answer, re.IGNORECASE) if final_answer_match: answer = final_answer_match.group(1).strip() if thinking: thinking = re.sub(r"|", "", thinking).strip() return {"thinking": thinking or None, "answer": answer, "raw": raw_text} print("Initializing INT4 Model Service...") gpt_model = QuantizedSkinGPTModel(DEFAULT_MODEL_PATH) print("INT4 service ready.") async def init_deepseek(): global deepseek_service print("\nInitializing DeepSeek service...") deepseek_service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY) if deepseek_service and deepseek_service.is_loaded: print("DeepSeek service is ready!") else: print("DeepSeek service not available, will return raw results") @asynccontextmanager async def lifespan(app: FastAPI): await init_deepseek() yield print("\nShutting down INT4 service...") app = FastAPI( title="SkinGPT-R1 INT4 API", description="INT4 quantized dermatology assistant backend", version="1.1.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) chat_states = {} pending_images = {} @app.post("/v1/upload/{state_id}") async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)): del survey try: file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg" unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}" file_path = TEMP_DIR / unique_name with file_path.open("wb") as buffer: shutil.copyfileobj(file.file, buffer) pending_images[state_id] = str(file_path) if state_id not in chat_states: chat_states[state_id] = [] return {"message": "Image uploaded successfully", "path": str(file_path)} except Exception as exc: raise HTTPException(status_code=500, detail=f"Upload failed: {exc}") from exc @app.post("/v1/predict/{state_id}") async def v1_predict(request: Request, state_id: str): try: data = await request.json() except Exception as exc: raise HTTPException(status_code=400, detail="Invalid JSON") from exc user_message = data.get("message", "") if not user_message: raise HTTPException(status_code=400, detail="Missing 'message' field") history = chat_states.get(state_id, []) current_content = [] if state_id in pending_images: img_path = pending_images.pop(state_id) current_content.append({"type": "image", "image": img_path}) if not history: user_message = f"You are a professional AI dermatology assistant.\n\n{user_message}" current_content.append({"type": "text", "text": user_message}) history.append({"role": "user", "content": current_content}) chat_states[state_id] = history try: response_text = await run_in_threadpool(gpt_model.generate_response, messages=history) except Exception as exc: chat_states[state_id].pop() raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]}) chat_states[state_id] = history return {"message": response_text} @app.post("/v1/reset/{state_id}") async def reset_chat(state_id: str): if state_id in chat_states: del chat_states[state_id] if state_id in pending_images: try: Path(pending_images[state_id]).unlink(missing_ok=True) except Exception: pass del pending_images[state_id] return {"message": "Chat history reset"} @app.get("/") async def root(): return { "name": "SkinGPT-R1 INT4 API", "version": "1.1.0", "status": "running", "description": "INT4 quantized dermatology assistant", } @app.get("/health") async def health_check(): return {"status": "healthy", "model_loaded": True} @app.post("/diagnose/stream") async def diagnose_stream( image: Optional[UploadFile] = File(None), text: str = Form(...), language: str = Form("zh"), ): language = language if language in ("zh", "en") else "zh" pil_image = None if image: contents = await image.read() pil_image = Image.open(BytesIO(contents)).convert("RGB") result_queue = Queue() generation_result = {"full_response": [], "parsed": None, "temp_image_path": None} def run_generation(): full_response = [] try: messages = [] current_content = [] system_prompt = ( "You are a professional AI dermatology assistant." if language == "en" else "你是一个专业的AI皮肤科助手。" ) if pil_image: temp_image_path = TEMP_DIR / f"temp_{uuid.uuid4().hex}.jpg" pil_image.save(temp_image_path) generation_result["temp_image_path"] = str(temp_image_path) current_content.append({"type": "image", "image": str(temp_image_path)}) current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"}) messages.append({"role": "user", "content": current_content}) for chunk in gpt_model.generate_response_stream( messages=messages, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, do_sample=DEFAULT_DO_SAMPLE, repetition_penalty=DEFAULT_REPETITION_PENALTY, ): full_response.append(chunk) result_queue.put(("delta", chunk)) response_text = "".join(full_response) generation_result["full_response"] = full_response generation_result["parsed"] = parse_diagnosis_result(response_text) result_queue.put(("generation_done", None)) except Exception as exc: result_queue.put(("error", str(exc))) async def event_generator(): gen_thread = Thread(target=run_generation) gen_thread.start() loop = asyncio.get_event_loop() while True: try: msg_type, data = await loop.run_in_executor( None, lambda: result_queue.get(timeout=0.1), ) if msg_type == "generation_done": break if msg_type == "delta": yield f"data: {json.dumps({'type': 'delta', 'text': data}, ensure_ascii=False)}\n\n" elif msg_type == "error": yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n" gen_thread.join() return except Empty: await asyncio.sleep(0.01) gen_thread.join() parsed = generation_result["parsed"] if not parsed: yield "data: {\"type\": \"error\", \"message\": \"Failed to parse response\"}\n\n" return raw_thinking = parsed["thinking"] raw_answer = parsed["answer"] refined_by_deepseek = False description = None thinking = raw_thinking answer = raw_answer if deepseek_service and deepseek_service.is_loaded: try: refined = await deepseek_service.refine_diagnosis( raw_answer=raw_answer, raw_thinking=raw_thinking, language=language, ) if refined["success"]: description = refined["description"] thinking = refined["analysis_process"] answer = refined["diagnosis_result"] refined_by_deepseek = True except Exception as exc: print(f"DeepSeek refinement failed, using original: {exc}") else: print("DeepSeek service not available, using raw results") final_payload = { "description": description, "thinking": thinking, "answer": answer, "raw": parsed["raw"], "refined_by_deepseek": refined_by_deepseek, "success": True, "message": "Diagnosis completed" if language == "en" else "诊断完成", } yield f"data: {json.dumps({'type': 'final', 'result': final_payload}, ensure_ascii=False)}\n\n" temp_path = generation_result.get("temp_image_path") if temp_path: try: Path(temp_path).unlink(missing_ok=True) except Exception: pass return StreamingResponse(event_generator(), media_type="text/event-stream") def main() -> None: uvicorn.run("app:app", host="0.0.0.0", port=5901, reload=False) if __name__ == "__main__": main()