import os import re import json import base64 import threading from pathlib import Path from typing import Any # Constants from demo.py BASE_DIR = Path(".") HF_TOKEN_PATH = BASE_DIR / "hf_token" HF_TOKEN = HF_TOKEN_PATH.read_text(encoding="utf-8").strip() or None if HF_TOKEN is not None: from huggingface_hub import login login(token=HF_TOKEN, add_to_git_credential=False) HF_MODEL = os.environ.get("HF_MODEL", "google/gemma-4-E2B-it") JAILBREAK_MODEL = os.environ.get("JAILBREAK_MODEL", "DerivedFunction1/xlmr-prompt-injection") JAILBREAK_THRESHOLD = float(os.environ.get("JAILBREAK_THRESHOLD", "0.5")) REFUSAL_LANGUAGE_MODEL = os.environ.get( "REFUSAL_LANGUAGE_MODEL", "polyglot-tagger/multilabel-language-identification", ) SUPPORTED_GEMMA_LANGS = { "EN", "ES", "FR", "DE", "IT", "PT", "NL", "DA", "RU", "PL", "ZH", "JA", "KO", "VI", "HI", "BN", "TH", "ID", "MS", "MR", "TE", "TA", "GU", "PA", "AR", "TR", "HE", "SW", } SUPPORTED_JAILBREAK_LANGS = { "EN", "AR", "DE", "ES", "FR", "HI", "IT", "JA", "KO", "NL", "TH", "ZH", } # Imports for model loading from transformers import AutoProcessor, Gemma4ForConditionalGeneration, BitsAndBytesConfig, pipeline # Model loading print(f"Loading model: {HF_MODEL}") _processor = AutoProcessor.from_pretrained(HF_MODEL, padding_side="left") _bnb_config = BitsAndBytesConfig( load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True, ) _model = Gemma4ForConditionalGeneration.from_pretrained( HF_MODEL, quantization_config=_bnb_config, device_map="auto", ) print(f"Loading jailbreak detector: {JAILBREAK_MODEL}") _jailbreak_pipe = pipeline("text-classification", model=JAILBREAK_MODEL) print(f"Loading refusal language detector: {REFUSAL_LANGUAGE_MODEL}") _refusal_language_pipe = pipeline("text-classification", model=REFUSAL_LANGUAGE_MODEL) # Tool call regex and markup stripping (from demo.py) TOOL_CALL_RE = re.compile( r"(?:<\|?tool_call\|?>|^)\s*" r"(?:call:)?(?P[a-zA-Z_]\w*)\s*" r"(?:\{|\()(?P.*?)(?:\}|\))\s*" r"(?P<\|?tool_call\|?>|||||$)", re.DOTALL, ) TOOL_CALL_MARKUP_RE = re.compile( r"<\|?tool_call\|?>.*?(?:<\|?tool_call\|?>||$)", re.DOTALL, ) def _strip_tool_call_markup(text: str) -> str: cleaned = (text or "").replace("\r", "").strip() if not cleaned: return "" cleaned = cleaned.replace("<|\"|>", '"') cleaned = TOOL_CALL_MARKUP_RE.sub("", cleaned) cleaned = re.sub(r"<\|?tool_response\|?>.*$", "", cleaned, flags=re.DOTALL) # Remove various special tokens and the REDIRECT token if present cleaned = cleaned.replace("<|turn>", "").replace("", "").replace("", "").replace("", "") cleaned = cleaned.replace("[REDIRECT]:", "") return cleaned.strip() def detect_jailbreak(text: str) -> dict: """Return detector metadata for a user message.""" result = _jailbreak_pipe(text, truncation=True, max_length=512)[0] label = str(result.get("label", "")).lower() score = float(result.get("score", 0.0)) unsafe_score = score if label == "unsafe" else (1.0 - score if label == "safe" else score) return { "score": unsafe_score, "blocked": unsafe_score >= JAILBREAK_THRESHOLD, "predicted_label": label, } def detect_refusal_language(text: str) -> str: result = _refusal_language_pipe(text, truncation=True, max_length=512)[0] label = str(result.get("label", "")).upper().strip() if label in SUPPORTED_GEMMA_LANGS: return label return "EN" def detect_preferred_language(text: str) -> str: result = _refusal_language_pipe(text, truncation=True, max_length=512)[0] label = str(result.get("label", "")).upper().strip() return label or "EN" def _sanitize_display_text(text: str, system_prompt: str | None = None) -> str: cleaned = _strip_tool_call_markup(text) if not cleaned: return "" return cleaned.strip() # These imports are needed for generate_response and generate_response_stream # They are imported here to avoid circular dependencies with demo.py from bob_resources import ( assistant_capabilities, call, validate, clarify_intent, store_policy, store_information, store_app_website, food_safety_endpoint, legal_endpoint, emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order ) def generate_response(messages: list, system_prompt: str) -> str: full = [{"role": "system", "content": system_prompt}] + messages inputs = _processor.apply_chat_template( full, tools=[assistant_capabilities, call, validate, clarify_intent, store_policy, store_information, store_app_website, food_safety_endpoint, legal_endpoint, emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order], tokenize=True, return_dict=True, return_tensors="pt", add_generation_prompt=True, ).to(_model.device) with __import__("torch").no_grad(): out = _model.generate( # pyright: ignore[reportAttributeAccessIssue] **inputs, max_new_tokens=400, temperature=0.7, do_sample=True, pad_token_id=_processor.tokenizer.eos_token_id, ) new_tokens = out[0][inputs["input_ids"].shape[1]:] return _processor.decode(new_tokens, skip_special_tokens=True).strip() def generate_response_stream(messages: list, system_prompt: str): full = [{"role": "system", "content": system_prompt}] + messages inputs = _processor.apply_chat_template( full, tools=[assistant_capabilities, call, validate, clarify_intent, store_policy, store_information, store_app_website, food_safety_endpoint, legal_endpoint, emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order], tokenize=True, return_dict=True, return_tensors="pt", add_generation_prompt=True, ).to(_model.device) from transformers import TextIteratorStreamer streamer = TextIteratorStreamer(_processor.tokenizer, skip_prompt=True, skip_special_tokens=False) thread = threading.Thread( target=_model.generate, # pyright: ignore[reportAttributeAccessIssue] kwargs={ **inputs, "max_new_tokens": 400, "temperature": 0.7, "do_sample": True, "pad_token_id": _processor.tokenizer.eos_token_id, "streamer": streamer, }, daemon=True, ) thread.start() generated = "" for chunk in streamer: generated += chunk yield chunk # Yield only the new delta chunk thread.join()