| | |
| | from __future__ import annotations |
| |
|
| | import os |
| | from typing import Any, Dict, Union |
| |
|
| | import torch |
| | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| |
|
| |
|
# Hard cap on encoder input length in tokens; longer prompts are truncated
# from the left in __call__ so the question at the end survives.
MAX_INPUT_TOKENS = 512
# Generation budget used when the request carries no "max_new_tokens" parameter.
DEFAULT_MAX_NEW_TOKENS = 128

# Fallback system prompt injected between context and question when the
# request does not supply its own "system_prompt" field. The quoted refusal
# sentence is part of the model contract — do not reword it.
DEFAULT_SYSTEM_PROMPT = (
    "You are Teapot, an open-source AI assistant optimized for low-end devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)
| |
|
| |
|
| | def _path_exists(p: str) -> bool: |
| | try: |
| | return os.path.exists(p) |
| | except Exception: |
| | return False |
| |
|
| |
|
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for the Teapot model.

    Loads a seq2seq tokenizer/model pair from a local directory, pins it to
    CPU, and answers requests of the form::

        {"inputs": "<raw prompt>"}
        {"inputs": {"context": ..., "question": ..., "system_prompt": ...},
         "parameters": {"max_new_tokens": ...}}

    returning ``{"generated_text": ...}``.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer + model from *path* and sanity-check vocab sizes.

        Args:
            path: Directory containing the exported repo
                (config.json, weights, spiece.model / tokenizer.json).

        Raises:
            RuntimeError: if tokenizer and model weights are irreconcilably
                out of sync (see checks below).
        """
        # Log which tokenizer/config artifacts are actually shipped — useful
        # when diagnosing slow-vs-fast tokenizer mismatches in the repo.
        spiece_path = os.path.join(path, "spiece.model")
        tokjson_path = os.path.join(path, "tokenizer.json")
        cfg_path = os.path.join(path, "config.json")

        print(f"[teapot] model_dir={path}")
        print(f"[teapot] exists config.json={_path_exists(cfg_path)} tokenizer.json={_path_exists(tokjson_path)} spiece.model={_path_exists(spiece_path)}")

        # use_fast=False forces the slow SentencePiece tokenizer so
        # spiece.model is authoritative even if tokenizer.json is stale.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            use_fast=False,
            model_max_length=MAX_INPUT_TOKENS,
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        # Endpoint targets low-end hardware: everything stays on CPU.
        self.device = torch.device("cpu")
        self.model.to(self.device)
        self.model.eval()

        # Three views of the vocabulary that must be consistent for decoding
        # to produce sensible text: tokenizer, config, and embedding matrix.
        tok_len = len(self.tokenizer)
        tok_vocab_size = getattr(self.tokenizer, "vocab_size", None)
        cfg_vocab = getattr(self.model.config, "vocab_size", None)
        emb_rows = int(self.model.get_input_embeddings().weight.shape[0])

        print(f"[teapot] tokenizer_class={type(self.tokenizer).__name__} use_fast={getattr(self.tokenizer, 'is_fast', None)}")
        print(f"[teapot] len(tokenizer)={tok_len} tokenizer.vocab_size={tok_vocab_size} model.config.vocab_size={cfg_vocab} embedding_rows={emb_rows}")
        print(f"[teapot] special_tokens: pad={self.tokenizer.pad_token} eos={self.tokenizer.eos_token} unk={self.tokenizer.unk_token}")

        # BUGFIX: T5-family checkpoints routinely pad the embedding matrix
        # past the tokenizer size for efficiency (e.g. flan-t5: 32128 rows vs
        # 32100 tokens), so emb_rows > tok_len is normal and harmless. Only
        # emb_rows < tok_len is fatal — token ids could then index past the
        # embedding table. The old `!=` check rejected valid checkpoints.
        if emb_rows < tok_len:
            raise RuntimeError(
                f"[teapot] FATAL: embedding_rows ({emb_rows}) < len(tokenizer) ({tok_len}). "
                "This means your model weights and tokenizer files are out of sync in the repo. "
                "Fix by re-saving model+tokenizer together after resize_token_embeddings."
            )
        if emb_rows != tok_len:
            print(f"[teapot] WARNING: embedding_rows ({emb_rows}) > len(tokenizer) ({tok_len}); extra rows are unused padding.")
        if cfg_vocab is not None and cfg_vocab != emb_rows:
            raise RuntimeError(
                f"[teapot] FATAL: model.config.vocab_size ({cfg_vocab}) != embedding_rows ({emb_rows}). "
                "Your config.json is inconsistent with the weights. Re-save model to update config."
            )

        self.system_prompt = DEFAULT_SYSTEM_PROMPT

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one inference request.

        Args:
            data: Parsed request JSON. ``data["inputs"]`` is either a raw
                prompt string or a dict with ``context``, ``question`` and an
                optional ``system_prompt`` override.
                ``data["parameters"]["max_new_tokens"]`` optionally bounds
                generation length.

        Returns:
            ``{"generated_text": <decoded model output>}``.

        Raises:
            ValueError: if the request shape is invalid.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError("Request must be JSON with an 'inputs' field.")

        inputs: Union[str, Dict[str, Any]] = data["inputs"]
        params = data.get("parameters") or {}

        # Clamp to >= 1 so a zero/negative request cannot break generate().
        max_new_tokens = max(1, int(params.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)))

        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, dict):
            context = inputs.get("context", "")
            question = inputs.get("question", "")
            system_prompt = inputs.get("system_prompt", self.system_prompt)
            # Prompt layout expected by the model: context, instructions,
            # then the question last.
            prompt = f"{context}\n{system_prompt}\n{question}\n"
        else:
            raise ValueError("'inputs' must be a string or an object with {context, question}.")

        enc = self.tokenizer(prompt, return_tensors="pt")
        input_ids = enc["input_ids"]
        attention_mask = enc.get("attention_mask")

        # Truncate from the LEFT so the question (at the end of the prompt)
        # always survives; only the oldest context tokens are dropped.
        if input_ids.shape[1] > MAX_INPUT_TOKENS:
            input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]

        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        # Deterministic greedy decoding; light repetition controls keep the
        # small model from looping on itself.
        out = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            num_beams=1,
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.05,
            no_repeat_ngram_size=3,
        )

        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return {"generated_text": text}