# Restructure inference and add INT4 serving (commit 52a881a, author: Yuhao)
from __future__ import annotations
import asyncio
import json
import os
import shutil
import sys
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from queue import Empty, Queue
from threading import Thread
from typing import Optional
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image
try:
from .model_utils import (
DEFAULT_DO_SAMPLE,
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_MODEL_PATH,
DEFAULT_REPETITION_PENALTY,
QuantizedSkinGPTModel,
)
except ImportError:
from model_utils import (
DEFAULT_DO_SAMPLE,
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_MODEL_PATH,
DEFAULT_REPETITION_PENALTY,
QuantizedSkinGPTModel,
)
try:
from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service
except ImportError:
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service
# Staging directory for uploaded images; created eagerly at import time so
# endpoints can assume it exists.
TEMP_DIR = Path(__file__).resolve().parents[1] / "temp_uploads"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
# Optional credential for the DeepSeek refinement service (may be None).
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
# Populated by init_deepseek() during app startup; None until then.
deepseek_service: Optional[DeepSeekService] = None
def parse_diagnosis_result(raw_text: str) -> dict:
    """Split a model response into its <think> and <answer> sections.

    Tries well-formed ``<think>...</think>`` / ``<answer>...</answer>`` pairs
    first, then falls back to unclosed tags, and finally treats the
    tag-stripped text as the answer. A trailing "Final Answer:" marker, if
    present, trims the answer to what follows it.

    Returns a dict with keys "thinking" (or None), "answer", and the
    untouched "raw" input.
    """
    import re

    def _section(pattern: str) -> Optional[str]:
        # Search raw_text and return the stripped capture group, or None.
        match = re.search(pattern, raw_text)
        return match.group(1).strip() if match else None

    thinking = _section(r"<think>([\s\S]*?)</think>")
    answer = _section(r"<answer>([\s\S]*?)</answer>")
    if not thinking:
        # Unclosed <think>: capture up to <answer> or end of text.
        thinking = _section(r"<think>([\s\S]*?)(?=<answer>|$)")
    if not answer:
        # Unclosed <answer>: capture everything after the tag.
        answer = _section(r"<answer>([\s\S]*?)$")
    if not answer:
        # No answer tag at all: strip think sections and stray tags,
        # falling back to the raw text if nothing remains.
        stripped = re.sub(r"<think>[\s\S]*?</think>", "", raw_text)
        stripped = re.sub(r"<think>[\s\S]*", "", stripped)
        stripped = re.sub(r"</?answer>", "", stripped)
        answer = stripped.strip() or raw_text
    if answer:
        answer = re.sub(r"</?think>|</?answer>", "", answer).strip()
        tail = re.search(r"Final Answer:\s*([\s\S]*)", answer, re.IGNORECASE)
        if tail:
            answer = tail.group(1).strip()
    if thinking:
        thinking = re.sub(r"</?think>|</?answer>", "", thinking).strip()
    return {"thinking": thinking or None, "answer": answer, "raw": raw_text}
# Load the INT4-quantized model at import time so the process fails fast if
# the weights are missing; all endpoints assume `gpt_model` is ready.
print("Initializing INT4 Model Service...")
gpt_model = QuantizedSkinGPTModel(DEFAULT_MODEL_PATH)
print("INT4 service ready.")
async def init_deepseek():
    """Create the global DeepSeek refinement service (best effort).

    Sets the module-level ``deepseek_service``; when the service cannot be
    loaded the app keeps running and serves unrefined results.
    """
    global deepseek_service
    print("\nInitializing DeepSeek service...")
    service = await get_deepseek_service(api_key=DEEPSEEK_API_KEY)
    deepseek_service = service
    ready = bool(service and service.is_loaded)
    if ready:
        print("DeepSeek service is ready!")
    else:
        print("DeepSeek service not available, will return raw results")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: bring up DeepSeek on startup, log on shutdown."""
    await init_deepseek()
    yield
    print("\nShutting down INT4 service...")
# Application object; lifespan wires DeepSeek startup/shutdown around it.
app = FastAPI(
    title="SkinGPT-R1 INT4 API",
    description="INT4 quantized dermatology assistant backend",
    version="1.1.0",
    lifespan=lifespan,
)
# NOTE(review): a literal "*" origin combined with allow_credentials=True is
# disallowed by the CORS spec for credentialed requests — the explicit
# localhost origins are likely what actually matters here; confirm whether
# the wildcard is intentional.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# In-memory, per-process session state keyed by the client-chosen state_id.
chat_states = {}  # state_id -> list of chat turns (user/assistant messages)
pending_images = {}  # state_id -> path of an upload awaiting the next /predict
@app.post("/v1/upload/{state_id}")
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
    """Stage an uploaded image for a chat session.

    Saves the file under TEMP_DIR with a unique name and records its path in
    ``pending_images`` so the next /v1/predict call for this state_id attaches
    it to the user turn.

    Raises:
        HTTPException(500): if saving the upload fails for any reason.
    """
    del survey  # accepted for frontend compatibility but unused
    try:
        # UploadFile.filename is optional; fall back to a jpg extension when
        # it is missing (the original code crashed on a None filename).
        original_name = file.filename or ""
        file_extension = original_name.rsplit(".", 1)[-1] if "." in original_name else "jpg"
        unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
        file_path = TEMP_DIR / unique_name
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        pending_images[state_id] = str(file_path)
        # Make sure a history list exists so the session is considered started.
        chat_states.setdefault(state_id, [])
        return {"message": "Image uploaded successfully", "path": str(file_path)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Upload failed: {exc}") from exc
@app.post("/v1/predict/{state_id}")
async def v1_predict(request: Request, state_id: str):
    """Run one chat turn: append the user message (plus any staged image),
    generate a reply with the INT4 model, and store both in history.

    Raises:
        HTTPException(400): malformed JSON body or missing 'message'.
        HTTPException(500): model inference failure (the failed user turn is
            rolled back and any staged image is re-staged for retry).
    """
    try:
        data = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON") from exc
    user_message = data.get("message", "")
    if not user_message:
        raise HTTPException(status_code=400, detail="Missing 'message' field")
    history = chat_states.get(state_id, [])
    current_content = []
    img_path = None
    if state_id in pending_images:
        img_path = pending_images.pop(state_id)
        current_content.append({"type": "image", "image": img_path})
    if not history:
        # First turn of the session: prepend the system prompt.
        user_message = f"You are a professional AI dermatology assistant.\n\n{user_message}"
    current_content.append({"type": "text", "text": user_message})
    history.append({"role": "user", "content": current_content})
    chat_states[state_id] = history
    try:
        response_text = await run_in_threadpool(gpt_model.generate_response, messages=history)
    except Exception as exc:
        # Roll back the failed user turn and re-stage the image so a retry
        # does not silently lose the upload (previously the image was gone).
        chat_states[state_id].pop()
        if img_path is not None:
            pending_images[state_id] = img_path
        raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
    history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
    chat_states[state_id] = history
    return {"message": response_text}
@app.post("/v1/reset/{state_id}")
async def reset_chat(state_id: str):
    """Drop stored conversation history and any staged upload for state_id."""
    chat_states.pop(state_id, None)
    staged = pending_images.pop(state_id, None)
    if staged is not None:
        # Best-effort cleanup of the staged image file.
        try:
            Path(staged).unlink(missing_ok=True)
        except Exception:
            pass
    return {"message": "Chat history reset"}
@app.get("/")
async def root():
    """Service metadata for quick manual checks."""
    info = {
        "name": "SkinGPT-R1 INT4 API",
        "version": "1.1.0",
        "status": "running",
        "description": "INT4 quantized dermatology assistant",
    }
    return info
@app.get("/health")
async def health_check():
    """Liveness probe; the model is loaded at import time, so always healthy."""
    return dict(status="healthy", model_loaded=True)
@app.post("/diagnose/stream")
async def diagnose_stream(
    image: Optional[UploadFile] = File(None),
    text: str = Form(...),
    language: str = Form("zh"),
):
    """Stream a diagnosis over SSE: token deltas, then one 'final' payload.

    Generation runs on a worker thread (the model API is blocking) and feeds
    tokens through a Queue back to the async event generator. After the stream
    finishes, the parsed result is optionally refined by the DeepSeek service
    before the final event is emitted.
    """
    # Only "zh" and "en" are supported; anything else falls back to "zh".
    language = language if language in ("zh", "en") else "zh"
    pil_image = None
    if image:
        contents = await image.read()
        pil_image = Image.open(BytesIO(contents)).convert("RGB")
    # Thread -> event-loop channel; items are
    # ("delta" | "generation_done" | "error", payload).
    result_queue = Queue()
    # Shared mutable state the worker thread fills in for the async side.
    generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
    def run_generation():
        # Worker thread body: blocking token generation, results via queue.
        full_response = []
        try:
            messages = []
            current_content = []
            system_prompt = (
                "You are a professional AI dermatology assistant."
                if language == "en"
                else "你是一个专业的AI皮肤科助手。"
            )
            if pil_image:
                # The model consumes image paths, so persist the upload first.
                temp_image_path = TEMP_DIR / f"temp_{uuid.uuid4().hex}.jpg"
                pil_image.save(temp_image_path)
                generation_result["temp_image_path"] = str(temp_image_path)
                current_content.append({"type": "image", "image": str(temp_image_path)})
            current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"})
            messages.append({"role": "user", "content": current_content})
            for chunk in gpt_model.generate_response_stream(
                messages=messages,
                max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                do_sample=DEFAULT_DO_SAMPLE,
                repetition_penalty=DEFAULT_REPETITION_PENALTY,
            ):
                full_response.append(chunk)
                result_queue.put(("delta", chunk))
            response_text = "".join(full_response)
            generation_result["full_response"] = full_response
            generation_result["parsed"] = parse_diagnosis_result(response_text)
            result_queue.put(("generation_done", None))
        except Exception as exc:
            # Surface the failure to the SSE consumer instead of dying silently.
            result_queue.put(("error", str(exc)))
    async def event_generator():
        # Async side: drain the queue without blocking the event loop.
        gen_thread = Thread(target=run_generation)
        gen_thread.start()
        loop = asyncio.get_event_loop()
        while True:
            try:
                # Short blocking get on an executor keeps the loop responsive.
                msg_type, data = await loop.run_in_executor(
                    None,
                    lambda: result_queue.get(timeout=0.1),
                )
                if msg_type == "generation_done":
                    break
                if msg_type == "delta":
                    yield f"data: {json.dumps({'type': 'delta', 'text': data}, ensure_ascii=False)}\n\n"
                elif msg_type == "error":
                    yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
                    gen_thread.join()
                    return
            except Empty:
                await asyncio.sleep(0.01)
        gen_thread.join()
        parsed = generation_result["parsed"]
        if not parsed:
            yield "data: {\"type\": \"error\", \"message\": \"Failed to parse response\"}\n\n"
            return
        raw_thinking = parsed["thinking"]
        raw_answer = parsed["answer"]
        refined_by_deepseek = False
        description = None
        thinking = raw_thinking
        answer = raw_answer
        if deepseek_service and deepseek_service.is_loaded:
            try:
                refined = await deepseek_service.refine_diagnosis(
                    raw_answer=raw_answer,
                    raw_thinking=raw_thinking,
                    language=language,
                )
                if refined["success"]:
                    description = refined["description"]
                    thinking = refined["analysis_process"]
                    answer = refined["diagnosis_result"]
                    refined_by_deepseek = True
            except Exception as exc:
                # Refinement is best-effort; fall back to the raw parse.
                print(f"DeepSeek refinement failed, using original: {exc}")
        else:
            print("DeepSeek service not available, using raw results")
        final_payload = {
            "description": description,
            "thinking": thinking,
            "answer": answer,
            "raw": parsed["raw"],
            "refined_by_deepseek": refined_by_deepseek,
            "success": True,
            "message": "Diagnosis completed" if language == "en" else "诊断完成",
        }
        yield f"data: {json.dumps({'type': 'final', 'result': final_payload}, ensure_ascii=False)}\n\n"
        # Best-effort cleanup of the temp image written by the worker thread.
        temp_path = generation_result.get("temp_image_path")
        if temp_path:
            try:
                Path(temp_path).unlink(missing_ok=True)
            except Exception:
                pass
    return StreamingResponse(event_generator(), media_type="text/event-stream")
def main() -> None:
    """Start the uvicorn server hosting this app (blocking call)."""
    serve_options = {"host": "0.0.0.0", "port": 5901, "reload": False}
    uvicorn.run("app:app", **serve_options)


if __name__ == "__main__":
    main()