| | from tclogger import logger |
| | from transformers import AutoTokenizer |
| |
|
| | from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED |
| |
|
| |
|
class TokenChecker:
    """Count prompt tokens for a model and validate them against its token limit."""

    # Some official checkpoints are gated on the HF Hub; these openly accessible
    # mirrors let AutoTokenizer.from_pretrained succeed without authentication.
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    # Fallback used when the requested model key is not in MODEL_MAP.
    DEFAULT_MODEL = "nous-mixtral-8x7b"

    def __init__(self, input_str: str, model: str):
        """
        Args:
            input_str: The prompt text whose tokens will be counted.
            model: Short model key; unknown keys fall back to DEFAULT_MODEL.
        """
        self.input_str = input_str
        # Fall back to a known model rather than raising KeyError on unknown keys.
        self.model = model if model in MODEL_MAP else self.DEFAULT_MODEL
        self.model_fullname = MODEL_MAP[self.model]

        # Prefer the ungated mirror when the official repo is gated.
        tokenizer_repo = self.GATED_MODEL_MAP.get(self.model, self.model_fullname)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)

    def count_tokens(self) -> int:
        """Return the number of tokens the prompt encodes to (also logs it)."""
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self) -> int:
        """Return the configured token limit for the selected model."""
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self) -> int:
        """Return tokens remaining after the prompt and the reserved budget."""
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Return True if the prompt fits under the limit; raise ValueError otherwise.

        Raises:
            ValueError: When the prompt leaves no redundancy under the limit.
        """
        # Count once: count_tokens() re-encodes the prompt and logs on every
        # call, so the original double invocation did that work twice.
        token_count = self.count_tokens()
        token_limit = self.get_token_limit()
        if token_limit - TOKEN_RESERVED - token_count <= 0:
            raise ValueError(
                f"Prompt exceeded token limit: {token_count} > {token_limit}"
            )
        return True
| |
|