from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time

# Define all available models (API key -> Hugging Face repo id)
MODELS = {
    "yuuki-best": "OpceanAI/Yuuki-best",
    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1"
}

app = FastAPI(
    title="Yuuki API",
    description="Local inference API for Yuuki models",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Cache of loaded models and tokenizers
loaded_models = {}
loaded_tokenizers = {}


def load_model(model_key: str):
    """Lazy load: only load the model the first time it is requested."""
    if model_key not in loaded_models:
        print(f"Loading {model_key}...")
        model_id = MODELS[model_key]
        loaded_tokenizers[model_key] = AutoTokenizer.from_pretrained(model_id)
        loaded_models[model_key] = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32
        ).to("cpu")
        loaded_models[model_key].eval()
        print(f"{model_key} ready!")
    return loaded_models[model_key], loaded_tokenizers[model_key]


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4000)
    model: str = Field(default="yuuki-best", description="Model to use")
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)


class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_generated: int
    time_ms: int


@app.get("/")
def root():
    return {
        "message": "Yuuki Local Inference API",
        "models": list(MODELS.keys()),
        "endpoints": {
            "health": "GET /health",
            "models": "GET /models",
            "generate": "POST /generate",
            "docs": "GET /docs"
        }
    }


@app.get("/health")
def health():
    return {
        "status": "ok",
        "available_models": list(MODELS.keys()),
        "loaded_models": list(loaded_models.keys())
    }


@app.get("/models")
def list_models():
    return {
        "models": [
            {"id": key, "name": value}
            for key, value in MODELS.items()
        ]
    }


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    # Validate that the requested model exists
    if req.model not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model. Available: {list(MODELS.keys())}"
        )

    try:
        start = time.time()

        # Load model (lazy load)
        model, tokenizer = load_model(req.model)

        inputs = tokenizer(
            req.prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024
        )
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )

        # Keep only the newly generated tokens (drop the echoed prompt)
        new_tokens = output[0][input_length:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        elapsed_ms = int((time.time() - start) * 1000)

        return GenerateResponse(
            response=response_text.strip(),
            model=req.model,
            tokens_generated=len(new_tokens),
            time_ms=elapsed_ms
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
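
# --- Usage sketch (assumptions: this file is saved as app.py; uvicorn and curl
# are installed; the port is the uvicorn default). Not part of the API itself. ---
#
# Start the server:
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
# Example request against the /generate endpoint:
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, Yuuki!", "model": "yuuki-best", "max_new_tokens": 64}'
#
# Interactive docs are served at GET /docs (FastAPI's built-in Swagger UI).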