import os import sys import pickle import torch import gradio as gr from huggingface_hub import snapshot_download # ====================== # CONFIGURACIÓN REPO HF # ====================== REPO_ID = "teszenofficial/MTP7" MODEL_FILE = "mtp_mini.pkl" # Asegúrate de que se llame así en tu repo TOKENIZER_FILE = "mtp_tokenizer.model" # Asegúrate de que se llame así en tu repo LOCAL_DIR = "mtptz_repo" # Nombre de la carpeta local donde se descarga # ====================== # DESCARGA Y CARGA DEL MODELO # ====================== def load_resources(): print(f"📦 Descargando modelo desde {REPO_ID}...") # 1. Descargar el repositorio a una carpeta local repo_path = snapshot_download( repo_id=REPO_ID, local_dir=LOCAL_DIR ) print(f"✅ Modelo descargado en: {repo_path}") # 2. Añadir la ruta al sys.path para poder importar model.py y tokenizer.py desde el repo sys.path.insert(0, repo_path) try: # Intentamos importar las clases desde los archivos descargados en el repo from model import MTPMiniModel from tokenizer import MTPTokenizer except ImportError as e: print(f"❌ ERROR: No se pudieron importar 'model' o 'tokenizer'.") print(f" Asegúrate de que subiste 'model.py' y 'tokenizer.py' al repo '{REPO_ID}'.") raise e # 3. Definir rutas completas model_path = os.path.join(repo_path, MODEL_FILE) tokenizer_path = os.path.join(repo_path, TOKENIZER_FILE) # Verificar si existen if not os.path.exists(model_path): raise FileNotFoundError(f"No se encontró {MODEL_FILE} en el repo.") if not os.path.exists(tokenizer_path): raise FileNotFoundError(f"No se encontró {TOKENIZER_FILE} en el repo.") # 4. Cargar Tokenizer tokenizer = MTPTokenizer(tokenizer_path) print(f"✅ Tokenizer cargado. Vocab size: {tokenizer.vocab_size()}") # 5. Cargar Modelo print(f"🧠 Cargando tensores...") with open(model_path, 'rb') as f: model_data = pickle.load(f) config = model_data['config'] state_dict = model_data['model_state_dict'] vocab_size = model_data['vocab_size'] # Reconstruir el Modelo use_swiglu = config['model'].get('use_swiglu', False) model = MTPMiniModel( vocab_size=vocab_size, d_model=config['model']['d_model'], n_layers=config['model']['n_layers'], n_heads=config['model']['n_heads'], d_ff=config['model']['d_ff'], max_seq_len=config['model']['max_seq_len'], dropout=0.0, use_swiglu=use_swiglu ) model.load_state_dict(state_dict) model.eval() DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(DEVICE) print(f"✅ Modelo cargado en {DEVICE}") return model, tokenizer, DEVICE # Cargar al inicio model, tokenizer, DEVICE = load_resources() # ====================== # FUNCIÓN DE GENERACIÓN # ====================== def generate_response(message, history, temperature, max_tokens, top_p): # Construir el prompt # Formato: ### Instrucción:\n{input}\n\n### Respuesta:\n prompt = f"### Instrucción:\n{message}\n\n### Respuesta:\n" # Tokenizar tokens = [tokenizer.bos_id()] + tokenizer.encode(prompt) input_ids = torch.tensor([tokens], device=DEVICE) # Generar usando el método del modelo with torch.no_grad(): output_ids = model.generate( input_ids, max_new_tokens=int(max_tokens), temperature=float(temperature), top_k=40, top_p=float(top_p), repetition_penalty=1.15, min_length=10, eos_token_id=tokenizer.eos_id() ) # Decodificar gen_tokens = output_ids[0, len(tokens):].tolist() safe_tokens = [] for t in gen_tokens: if 0 <= t < tokenizer.vocab_size() and t != tokenizer.eos_id(): safe_tokens.append(t) elif t == tokenizer.eos_id(): break response = tokenizer.decode(safe_tokens).strip() # Limpieza básica if "### Instrucción:" in response: response = response.split("### Instrucción:")[0].strip() return response # ====================== # INTERFAZ GRADIO # ====================== with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("# 🤖 MTP-7 Chat (Demo)") gr.Markdown(f"Modelo cargado desde `teszenofficial/MTP7` en **{DEVICE}**.") chat_interface = gr.ChatInterface( fn=generate_response, additional_inputs=[ gr.Slider(0.1, 2.0, value=0.7, label="Temperatura (Creatividad)"), gr.Slider(50, 300, value=150, label="Máximos Tokens"), gr.Slider(0.1, 1.0, value=0.92, label="Top-p (Nucleus)"), ], examples=[ ["¿Cuál es la capital de Francia?", 0.7, 150, 0.92], ["Explica qué es la relatividad.", 0.7, 150, 0.92] ], cache_examples=False ) if __name__ == "__main__": demo.launch()