| import os |
| import re |
| import json |
| import base64 |
| import threading |
| from pathlib import Path |
| from typing import Any |
|
|
| |
| BASE_DIR = Path(".") |
| HF_TOKEN_PATH = BASE_DIR / "hf_token" |
| HF_TOKEN = HF_TOKEN_PATH.read_text(encoding="utf-8").strip() or None |
| if HF_TOKEN is not None: |
| from huggingface_hub import login |
| login(token=HF_TOKEN, add_to_git_credential=False) |
| HF_MODEL = os.environ.get("HF_MODEL", "google/gemma-4-E2B-it") |
| JAILBREAK_MODEL = os.environ.get("JAILBREAK_MODEL", "DerivedFunction1/xlmr-prompt-injection") |
| JAILBREAK_THRESHOLD = float(os.environ.get("JAILBREAK_THRESHOLD", "0.5")) |
| REFUSAL_LANGUAGE_MODEL = os.environ.get( |
| "REFUSAL_LANGUAGE_MODEL", |
| "polyglot-tagger/multilabel-language-identification", |
| ) |
|
|
# Upper-cased two-letter language codes accepted by detect_refusal_language();
# predictions outside this set fall back to "EN" there.
# NOTE(review): presumably the set of languages the Gemma chat model handles
# well — confirm against the model card.
SUPPORTED_GEMMA_LANGS = {
    "EN", "ES", "FR", "DE", "IT", "PT", "NL",
    "DA", "RU", "PL",
    "ZH", "JA", "KO", "VI",
    "HI", "BN", "TH", "ID", "MS", "MR", "TE", "TA", "GU", "PA",
    "AR", "TR", "HE", "SW",
}
|
|
# NOTE(review): presumably the languages the jailbreak classifier supports
# (confirm against the JAILBREAK_MODEL card). Not referenced in this chunk;
# callers elsewhere may use it.
SUPPORTED_JAILBREAK_LANGS = {
    "EN", "AR", "DE", "ES", "FR", "HI", "IT", "JA", "KO", "NL", "TH", "ZH",
}
|
|
| |
from transformers import AutoProcessor, Gemma4ForConditionalGeneration, BitsAndBytesConfig, pipeline


# --- Model loading -----------------------------------------------------------
# Runs at import time: downloads/loads three models (chat model, jailbreak
# detector, refusal-language detector). Importing this module is therefore
# slow and has network/GPU side effects.

print(f"Loading model: {HF_MODEL}")
# Left padding — standard for decoder-only generation with batched inputs.
_processor = AutoProcessor.from_pretrained(HF_MODEL, padding_side="left")
# 8-bit quantization; fp32 CPU offload lets layers that don't fit on the GPU
# spill to CPU memory instead of failing to load.
_bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)
_model = Gemma4ForConditionalGeneration.from_pretrained(
    HF_MODEL,
    quantization_config=_bnb_config,
    device_map="auto",
)


print(f"Loading jailbreak detector: {JAILBREAK_MODEL}")
_jailbreak_pipe = pipeline("text-classification", model=JAILBREAK_MODEL)


print(f"Loading refusal language detector: {REFUSAL_LANGUAGE_MODEL}")
_refusal_language_pipe = pipeline("text-classification", model=REFUSAL_LANGUAGE_MODEL)
|
|
| |
| TOOL_CALL_RE = re.compile( |
| r"(?:<\|?tool_call\|?>|^)\s*" |
| r"(?:call:)?(?P<name>[a-zA-Z_]\w*)\s*" |
| r"(?:\{|\()(?P<args>.*?)(?:\}|\))\s*" |
| r"(?P<close><\|?tool_call\|?>|<eos>|<end_of_turn>|<turn\|?>|</s>|$)", |
| re.DOTALL, |
| ) |
|
|
| TOOL_CALL_MARKUP_RE = re.compile( |
| r"<\|?tool_call\|?>.*?(?:<\|?tool_call\|?>|<eos>|$)", |
| re.DOTALL, |
| ) |
|
|
|
|
| def _strip_tool_call_markup(text: str) -> str: |
| cleaned = (text or "").replace("\r", "").strip() |
| if not cleaned: |
| return "" |
|
|
| cleaned = cleaned.replace("<|\"|>", '"') |
| cleaned = TOOL_CALL_MARKUP_RE.sub("", cleaned) |
| cleaned = re.sub(r"<\|?tool_response\|?>.*$", "", cleaned, flags=re.DOTALL) |
| |
| cleaned = cleaned.replace("<|turn>", "").replace("<turn|>", "").replace("<eos>", "").replace("</s>", "") |
| cleaned = cleaned.replace("[REDIRECT]:", "") |
| return cleaned.strip() |
|
|
|
|
def detect_jailbreak(text: str) -> dict:
    """Run the jailbreak classifier on a user message.

    Returns a dict with the unsafe-probability ("score"), whether that score
    crosses JAILBREAK_THRESHOLD ("blocked"), and the raw predicted label.
    """
    prediction = _jailbreak_pipe(text, truncation=True, max_length=512)[0]
    label = str(prediction.get("label", "")).lower()
    raw_score = float(prediction.get("score", 0.0))

    # The pipeline reports confidence in its *predicted* label; flip it for
    # "safe" so the returned score is always the unsafe probability. Unknown
    # labels pass the raw score through unchanged.
    if label == "safe":
        unsafe_score = 1.0 - raw_score
    else:
        unsafe_score = raw_score

    return {
        "score": unsafe_score,
        "blocked": unsafe_score >= JAILBREAK_THRESHOLD,
        "predicted_label": label,
    }
|
|
|
|
def detect_refusal_language(text: str) -> str:
    """Predict the language code to phrase a refusal in.

    Falls back to "EN" when the prediction is not a Gemma-supported code.
    """
    prediction = _refusal_language_pipe(text, truncation=True, max_length=512)[0]
    code = str(prediction.get("label", "")).upper().strip()
    return code if code in SUPPORTED_GEMMA_LANGS else "EN"
|
|
|
|
def detect_preferred_language(text: str) -> str:
    """Best-effort language code for *text*; "EN" when the classifier
    produces an empty label. Unlike detect_refusal_language, the result is
    NOT restricted to SUPPORTED_GEMMA_LANGS."""
    top = _refusal_language_pipe(text, truncation=True, max_length=512)[0]
    code = str(top.get("label", "")).upper().strip()
    return code or "EN"
|
|
|
|
def _sanitize_display_text(text: str, system_prompt: str | None = None) -> str:
    """Return user-displayable text with tool-call markup stripped.

    The system_prompt parameter is currently unused; kept for interface
    compatibility with callers.
    """
    stripped = _strip_tool_call_markup(text)
    return stripped.strip() if stripped else ""
|
|
|
|
| |
| |
| from bob_resources import ( |
| assistant_capabilities, |
| call, |
| validate, |
| clarify_intent, |
| store_policy, |
| store_information, |
| store_app_website, |
| food_safety_endpoint, |
| legal_endpoint, |
| emergency_crisis, |
| apply_discount, |
| loyalty_program, |
| competitor_mentions, |
| take_order, |
| ) |
|
|
def generate_response(messages: list, system_prompt: str) -> str:
    """Generate one complete (non-streaming) assistant reply.

    Args:
        messages: chat history as role/content dicts.
        system_prompt: system instruction prepended to the conversation.

    Returns:
        The decoded new text only (prompt tokens excluded), with special
        tokens stripped.
    """
    # Proper local import instead of the __import__("torch") hack; torch is
    # already loaded by transformers, so this is effectively free.
    import torch

    full = [{"role": "system", "content": system_prompt}] + messages
    inputs = _processor.apply_chat_template(
        full,
        tools=[assistant_capabilities, call, validate, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(_model.device)
    with torch.no_grad():  # inference only — don't build an autograd graph
        out = _model.generate(
            **inputs,
            max_new_tokens=400,
            temperature=0.7,
            do_sample=True,
            pad_token_id=_processor.tokenizer.eos_token_id,
        )
    # Slice off the prompt so only newly generated tokens are decoded.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return _processor.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
def generate_response_stream(messages: list, system_prompt: str):
    """Stream an assistant reply, yielding the *cumulative* text so far.

    Each yielded value is a successively longer prefix of the final reply.
    Special tokens are kept (skip_special_tokens=False) so downstream markup
    stripping can see them.

    Fix: the background generate thread previously ran WITHOUT torch.no_grad,
    unlike generate_response — it now disables autograd to avoid building an
    unnecessary graph during inference.
    """
    import torch
    from transformers import TextIteratorStreamer

    full = [{"role": "system", "content": system_prompt}] + messages
    inputs = _processor.apply_chat_template(
        full,
        tools=[assistant_capabilities, call, validate, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(_model.device)

    streamer = TextIteratorStreamer(_processor.tokenizer, skip_prompt=True, skip_special_tokens=False)

    def _run_generation() -> None:
        # Mirrors generate_response's generation arguments exactly.
        with torch.no_grad():
            _model.generate(
                **inputs,
                max_new_tokens=400,
                temperature=0.7,
                do_sample=True,
                pad_token_id=_processor.tokenizer.eos_token_id,
                streamer=streamer,
            )

    thread = threading.Thread(target=_run_generation, daemon=True)
    thread.start()

    generated = ""
    for chunk in streamer:
        generated += chunk
        yield generated
    thread.join()