File size: 5,266 Bytes
33112c4
 
 
60ae096
33112c4
 
8334c0b
33112c4
 
21235f2
8334c0b
33112c4
 
 
 
 
 
 
 
 
 
21235f2
60ae096
 
 
 
 
 
33112c4
60ae096
21235f2
60ae096
 
 
 
 
 
 
 
 
33112c4
 
 
 
 
62aec62
21235f2
 
 
33112c4
21235f2
60ae096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33112c4
8334c0b
21235f2
33112c4
60ae096
 
21235f2
33112c4
 
 
60ae096
 
21235f2
8334c0b
 
 
 
33112c4
 
21235f2
33112c4
8334c0b
 
21235f2
60ae096
21235f2
60ae096
21235f2
 
33112c4
 
21235f2
 
33112c4
 
21235f2
33112c4
21235f2
 
62aec62
33112c4
 
60ae096
 
 
21235f2
 
33112c4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# handler.py
from __future__ import annotations

import os
from typing import Any, Dict, Union

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


# Hard cap on encoder input length; longer prompts are left-truncated in
# EndpointHandler.__call__ so the most recent tokens are kept.
MAX_INPUT_TOKENS: int = 512
# Generation budget used when the request supplies no "max_new_tokens" parameter.
DEFAULT_MAX_NEW_TOKENS: int = 128

# System prompt placed between context and question when the request does not
# provide its own "system_prompt" field (see EndpointHandler.__call__).
DEFAULT_SYSTEM_PROMPT = (
    "You are Teapot, an open-source AI assistant optimized for low-end devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)


def _path_exists(p: str) -> bool:
    try:
        return os.path.exists(p)
    except Exception:
        return False


class EndpointHandler:
    """CPU-only Inference Endpoints handler for the Teapot seq2seq model."""

    def __init__(self, path: str = ""):
        """Load tokenizer and model from *path*, then verify vocab consistency.

        Raises:
            RuntimeError: when the tokenizer, embedding matrix, or config.json
                disagree on vocabulary size (i.e. repo files are out of sync).
        """
        # Log which of the key repo files are actually present before loading,
        # so a mis-mounted repo is obvious from the startup logs.
        cfg_path = os.path.join(path, "config.json")
        tokjson_path = os.path.join(path, "tokenizer.json")
        spiece_path = os.path.join(path, "spiece.model")

        print(f"[teapot] model_dir={path}")
        print(f"[teapot] exists config.json={_path_exists(cfg_path)} tokenizer.json={_path_exists(tokjson_path)} spiece.model={_path_exists(spiece_path)}")

        # use_fast=False forces the slow SentencePiece tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            use_fast=False,
            model_max_length=MAX_INPUT_TOKENS,
        )

        self.device = torch.device("cpu")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

        # Cross-check tokenizer size vs. embedding rows vs. config vocab_size;
        # a mismatch means weights and tokenizer files were saved separately.
        tok_len = len(self.tokenizer)                       # includes added tokens
        tok_vocab_size = getattr(self.tokenizer, "vocab_size", None)  # base vocab (T5 SP)
        cfg_vocab = getattr(self.model.config, "vocab_size", None)
        emb_rows = int(self.model.get_input_embeddings().weight.shape[0])

        print(f"[teapot] tokenizer_class={type(self.tokenizer).__name__} use_fast={getattr(self.tokenizer, 'is_fast', None)}")
        print(f"[teapot] len(tokenizer)={tok_len} tokenizer.vocab_size={tok_vocab_size} model.config.vocab_size={cfg_vocab} embedding_rows={emb_rows}")
        print(f"[teapot] special_tokens: pad={self.tokenizer.pad_token} eos={self.tokenizer.eos_token} unk={self.tokenizer.unk_token}")

        if tok_len != emb_rows:
            raise RuntimeError(
                f"[teapot] FATAL: embedding_rows ({emb_rows}) != len(tokenizer) ({tok_len}). "
                "This means your model weights and tokenizer files are out of sync in the repo. "
                "Fix by re-saving model+tokenizer together after resize_token_embeddings."
            )
        if cfg_vocab is not None and cfg_vocab != emb_rows:
            raise RuntimeError(
                f"[teapot] FATAL: model.config.vocab_size ({cfg_vocab}) != embedding_rows ({emb_rows}). "
                "Your config.json is inconsistent with the weights. Re-save model to update config."
            )

        self.system_prompt = DEFAULT_SYSTEM_PROMPT

    def _build_prompt(self, inputs: Union[str, Dict[str, Any]]) -> str:
        """Turn the raw request 'inputs' into the flat prompt string fed to the model."""
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            context = inputs.get("context", "")
            question = inputs.get("question", "")
            system_prompt = inputs.get("system_prompt", self.system_prompt)
            return f"{context}\n{system_prompt}\n{question}\n"
        raise ValueError("'inputs' must be a string or an object with {context, question}.")

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one inference request.

        Args:
            data: Parsed request JSON; must contain 'inputs' (a string, or an
                object with context/question/system_prompt) and may contain a
                'parameters' object with 'max_new_tokens'.

        Returns:
            {'generated_text': <decoded model output>}

        Raises:
            ValueError: on a malformed request payload.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError("Request must be JSON with an 'inputs' field.")

        params = data.get("parameters") or {}
        max_new_tokens = int(params.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
        prompt = self._build_prompt(data["inputs"])

        encoded = self.tokenizer(prompt, return_tensors="pt")
        ids = encoded["input_ids"]
        mask = encoded.get("attention_mask")

        # Left-truncate so the most recent tokens survive an over-long prompt.
        if ids.shape[1] > MAX_INPUT_TOKENS:
            ids = ids[:, -MAX_INPUT_TOKENS:]
            mask = None if mask is None else mask[:, -MAX_INPUT_TOKENS:]

        ids = ids.to(self.device)
        if mask is not None:
            mask = mask.to(self.device)

        output_ids = self.model.generate(
            input_ids=ids,
            attention_mask=mask,
            do_sample=False,
            num_beams=1,
            max_new_tokens=max_new_tokens,
            # Band-aid to prevent pathological repeats, but not a real fix:
            repetition_penalty=1.05,
            no_repeat_ngram_size=3,
        )

        decoded = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return {"generated_text": decoded}