# handler.py
from __future__ import annotations
import os
from typing import Any, Dict, Union
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Hard cap on the number of input tokens fed to the model; longer prompts
# are left-truncated in __call__ (most recent tokens are kept).
MAX_INPUT_TOKENS = 512

# Generation budget used when the request carries no "max_new_tokens" parameter.
DEFAULT_MAX_NEW_TOKENS = 128

# System prompt inserted between context and question for dict-shaped inputs.
# NOTE: the quoted fallback sentence is an exact-match contract with clients —
# do not reword it.
DEFAULT_SYSTEM_PROMPT = (
    "You are Teapot, an open-source AI assistant optimized for low-end devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)
def _path_exists(p: str) -> bool:
try:
return os.path.exists(p)
except Exception:
return False
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for the Teapot seq2seq model.

    At startup it loads the tokenizer and model from ``path`` on CPU,
    validates that tokenizer and model vocabularies agree, and then serves
    generation requests shaped as ``{"inputs": ..., "parameters": {...}}``.
    """

    def __init__(self, path: str = ""):
        """Load and sanity-check the model/tokenizer from the mounted repo.

        Args:
            path: Directory containing the model repo files
                (``config.json``, ``tokenizer.json``/``spiece.model``, weights).

        Raises:
            RuntimeError: if the embedding row count disagrees with
                ``len(tokenizer)`` or with ``config.vocab_size`` — i.e. the
                weights and tokenizer files in the repo are out of sync.
        """
        # Sanity: ensure key files exist in the mounted repo (diagnostic
        # only — a missing file is logged here, the actual failure would
        # surface in from_pretrained below).
        spiece_path = os.path.join(path, "spiece.model")
        tokjson_path = os.path.join(path, "tokenizer.json")
        cfg_path = os.path.join(path, "config.json")
        print(f"[teapot] model_dir={path}")
        print(f"[teapot] exists config.json={_path_exists(cfg_path)} tokenizer.json={_path_exists(tokjson_path)} spiece.model={_path_exists(spiece_path)}")
        # Force the slow SentencePiece tokenizer (use_fast=False) so the
        # spiece.model vocabulary is authoritative rather than tokenizer.json.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            use_fast=False,
            model_max_length=MAX_INPUT_TOKENS,
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        # CPU-only deployment target (model is optimized for low-end devices).
        self.device = torch.device("cpu")
        self.model.to(self.device)
        self.model.eval()
        # ----------------------------
        # CRITICAL CONSISTENCY CHECKS
        # ----------------------------
        tok_len = len(self.tokenizer)  # includes added tokens
        tok_vocab_size = getattr(self.tokenizer, "vocab_size", None)  # base vocab (T5 SP)
        cfg_vocab = getattr(self.model.config, "vocab_size", None)
        emb_rows = int(self.model.get_input_embeddings().weight.shape[0])
        print(f"[teapot] tokenizer_class={type(self.tokenizer).__name__} use_fast={getattr(self.tokenizer, 'is_fast', None)}")
        print(f"[teapot] len(tokenizer)={tok_len} tokenizer.vocab_size={tok_vocab_size} model.config.vocab_size={cfg_vocab} embedding_rows={emb_rows}")
        print(f"[teapot] special_tokens: pad={self.tokenizer.pad_token} eos={self.tokenizer.eos_token} unk={self.tokenizer.unk_token}")
        # If you ever resized embeddings, these MUST match:
        # - embedding rows must equal len(tokenizer)
        # - config vocab_size should match embedding rows
        # Failing fast here beats generating garbage token ids later.
        if emb_rows != tok_len:
            raise RuntimeError(
                f"[teapot] FATAL: embedding_rows ({emb_rows}) != len(tokenizer) ({tok_len}). "
                "This means your model weights and tokenizer files are out of sync in the repo. "
                "Fix by re-saving model+tokenizer together after resize_token_embeddings."
            )
        if cfg_vocab is not None and cfg_vocab != emb_rows:
            raise RuntimeError(
                f"[teapot] FATAL: model.config.vocab_size ({cfg_vocab}) != embedding_rows ({emb_rows}). "
                "Your config.json is inconsistent with the weights. Re-save model to update config."
            )
        self.system_prompt = DEFAULT_SYSTEM_PROMPT

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one inference request.

        Args:
            data: Request payload. ``data["inputs"]`` is either a raw prompt
                string or a dict with optional ``context``, ``question`` and
                ``system_prompt`` keys. ``data["parameters"]`` may carry
                ``max_new_tokens`` (defaults to DEFAULT_MAX_NEW_TOKENS).

        Returns:
            ``{"generated_text": <decoded model output>}``.

        Raises:
            ValueError: if ``inputs`` is missing or of an unsupported type.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError("Request must be JSON with an 'inputs' field.")
        inputs: Union[str, Dict[str, Any]] = data["inputs"]
        # `or {}` also covers an explicit "parameters": null in the payload.
        params = data.get("parameters") or {}
        max_new_tokens = int(params.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, dict):
            # Teapot prompt layout: context, then system prompt, then question.
            context = inputs.get("context", "")
            question = inputs.get("question", "")
            system_prompt = inputs.get("system_prompt", self.system_prompt)
            prompt = f"{context}\n{system_prompt}\n{question}\n"
        else:
            raise ValueError("'inputs' must be a string or an object with {context, question}.")
        enc = self.tokenizer(prompt, return_tensors="pt")
        input_ids = enc["input_ids"]
        attention_mask = enc.get("attention_mask")
        # Keep most recent tokens (left truncate) — drops the oldest context
        # rather than the question at the end of the prompt.
        if input_ids.shape[1] > MAX_INPUT_TOKENS:
            input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        # Greedy decoding (do_sample=False, num_beams=1) for deterministic,
        # short answers.
        out = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            num_beams=1,
            max_new_tokens=max_new_tokens,
            # Band-aid to prevent pathological repeats, but not a real fix:
            repetition_penalty=1.05,
            no_repeat_ngram_size=3,
        )
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return {"generated_text": text}