# Yuuki-api / app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
# Define all available models
MODELS = {
    "yuuki-best": "OpceanAI/Yuuki-best",
    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1"
}
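# The values above are Hugging Face Hub model IDs; from_pretrained() below
# downloads and caches them locally on first use.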
app = FastAPI(
title="Yuuki API",
description="Local inference API for Yuuki models",
version="1.0.0"
)
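# Wide-open CORS: convenient for local development, but worth restricting
# allow_origins before exposing this API publicly.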
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Cache of loaded models and tokenizers
loaded_models = {}
loaded_tokenizers = {}
def load_model(model_key: str):
    """Lazy load: only load the model the first time it is needed."""
    if model_key not in loaded_models:
        print(f"Loading {model_key}...")
        model_id = MODELS[model_key]
        loaded_tokenizers[model_key] = AutoTokenizer.from_pretrained(model_id)
        # Full-precision weights on CPU
        loaded_models[model_key] = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32
        ).to("cpu")
        loaded_models[model_key].eval()
        print(f"{model_key} ready!")
    return loaded_models[model_key], loaded_tokenizers[model_key]
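# Optional warm-up sketch (an assumption, not part of the original design):
# calling load_model() in a startup hook trades a slower boot for zero
# first-request latency. Left commented out to preserve lazy loading.
# @app.on_event("startup")
# def _preload_default():
#     load_model("yuuki-best")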
class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4000)
    model: str = Field(default="yuuki-best", description="Model to use")
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
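# Example request body accepted by GenerateRequest (values are illustrative):
# {
#     "prompt": "Hello, Yuuki!",
#     "model": "yuuki-best",
#     "max_new_tokens": 64,
#     "temperature": 0.7,
#     "top_p": 0.95
# }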
class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_generated: int
    time_ms: int
@app.get("/")
def root():
    return {
        "message": "Yuuki Local Inference API",
        "models": list(MODELS.keys()),
        "endpoints": {
            "health": "GET /health",
            "models": "GET /models",
            "generate": "POST /generate",
            "docs": "GET /docs"
        }
    }
@app.get("/health")
def health():
    return {
        "status": "ok",
        "available_models": list(MODELS.keys()),
        "loaded_models": list(loaded_models.keys())
    }
@app.get("/models")
def list_models():
    return {
        "models": [
            {"id": key, "name": value}
            for key, value in MODELS.items()
        ]
    }
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    # Validate that the requested model exists
    if req.model not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model. Available: {list(MODELS.keys())}"
        )
    try:
        start = time.time()
        # Load model (lazy load)
        model, tokenizer = load_model(req.model)
        inputs = tokenizer(
            req.prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024
        )
        input_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )
        # Keep only the newly generated tokens, dropping the prompt
        new_tokens = output[0][input_length:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        elapsed_ms = int((time.time() - start) * 1000)
        return GenerateResponse(
            response=response_text.strip(),
            model=req.model,
            tokens_generated=len(new_tokens),
            time_ms=elapsed_ms
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
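# A minimal way to run the server locally. uvicorn is assumed to be
# installed alongside fastapi; host and port are illustrative defaults.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example call once the server is up (hypothetical prompt and values):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello, Yuuki!", "max_new_tokens": 64}'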