# Source metadata: uploaded by DerivedFunction1, commit 6a3abc5 ("add")
import os
import re
import json
import base64
import threading
from pathlib import Path
from typing import Any
# Constants from demo.py
BASE_DIR = Path(".")
HF_TOKEN_PATH = BASE_DIR / "hf_token"
# Read the Hugging Face token only when the file exists; the unguarded
# read_text() raised FileNotFoundError on machines without an hf_token file,
# even though the trailing `or None` shows a missing token was meant to be OK.
HF_TOKEN = (
    HF_TOKEN_PATH.read_text(encoding="utf-8").strip() or None
    if HF_TOKEN_PATH.is_file()
    else None
)
if HF_TOKEN is not None:
    from huggingface_hub import login
    login(token=HF_TOKEN, add_to_git_credential=False)

# Model identifiers and thresholds; all overridable via environment variables.
HF_MODEL = os.environ.get("HF_MODEL", "google/gemma-4-E2B-it")
JAILBREAK_MODEL = os.environ.get("JAILBREAK_MODEL", "DerivedFunction1/xlmr-prompt-injection")
JAILBREAK_THRESHOLD = float(os.environ.get("JAILBREAK_THRESHOLD", "0.5"))
REFUSAL_LANGUAGE_MODEL = os.environ.get(
    "REFUSAL_LANGUAGE_MODEL",
    "polyglot-tagger/multilabel-language-identification",
)

# Languages the Gemma chat model is expected to handle (see detect_refusal_language).
SUPPORTED_GEMMA_LANGS = {
    "EN", "ES", "FR", "DE", "IT", "PT", "NL",
    "DA", "RU", "PL",
    "ZH", "JA", "KO", "VI",
    "HI", "BN", "TH", "ID", "MS", "MR", "TE", "TA", "GU", "PA",
    "AR", "TR", "HE", "SW",
}
# Languages the jailbreak classifier was trained on.
SUPPORTED_JAILBREAK_LANGS = {
    "EN", "AR", "DE", "ES", "FR", "HI", "IT", "JA", "KO", "NL", "TH", "ZH",
}
# Imports for model loading
from transformers import AutoProcessor, Gemma4ForConditionalGeneration, BitsAndBytesConfig, pipeline
# Model loading
print(f"Loading model: {HF_MODEL}")
# Left padding so generated tokens append after the prompt in batched decoding.
_processor = AutoProcessor.from_pretrained(HF_MODEL, padding_side="left")
# 8-bit quantization; fp32 CPU offload lets layers that don't fit on the GPU
# spill to CPU memory (works together with device_map="auto" below).
_bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)
_model = Gemma4ForConditionalGeneration.from_pretrained(
    HF_MODEL,
    quantization_config=_bnb_config,
    device_map="auto",
)
print(f"Loading jailbreak detector: {JAILBREAK_MODEL}")
# Safe/unsafe text classifier consumed by detect_jailbreak().
_jailbreak_pipe = pipeline("text-classification", model=JAILBREAK_MODEL)
print(f"Loading refusal language detector: {REFUSAL_LANGUAGE_MODEL}")
# Language-ID classifier consumed by detect_refusal_language() /
# detect_preferred_language().
_refusal_language_pipe = pipeline("text-classification", model=REFUSAL_LANGUAGE_MODEL)
# Tool call regex and markup stripping (from demo.py)
TOOL_CALL_RE = re.compile(
r"(?:<\|?tool_call\|?>|^)\s*"
r"(?:call:)?(?P<name>[a-zA-Z_]\w*)\s*"
r"(?:\{|\()(?P<args>.*?)(?:\}|\))\s*"
r"(?P<close><\|?tool_call\|?>|<eos>|<end_of_turn>|<turn\|?>|</s>|$)",
re.DOTALL,
)
TOOL_CALL_MARKUP_RE = re.compile(
r"<\|?tool_call\|?>.*?(?:<\|?tool_call\|?>|<eos>|$)",
re.DOTALL,
)
def _strip_tool_call_markup(text: str) -> str:
cleaned = (text or "").replace("\r", "").strip()
if not cleaned:
return ""
cleaned = cleaned.replace("<|\"|>", '"')
cleaned = TOOL_CALL_MARKUP_RE.sub("", cleaned)
cleaned = re.sub(r"<\|?tool_response\|?>.*$", "", cleaned, flags=re.DOTALL)
# Remove various special tokens and the REDIRECT token if present
cleaned = cleaned.replace("<|turn>", "").replace("<turn|>", "").replace("<eos>", "").replace("</s>", "")
cleaned = cleaned.replace("[REDIRECT]:", "")
return cleaned.strip()
def detect_jailbreak(text: str) -> dict:
    """Return detector metadata for a user message."""
    prediction = _jailbreak_pipe(text, truncation=True, max_length=512)[0]
    label = str(prediction.get("label", "")).lower()
    raw_score = float(prediction.get("score", 0.0))
    # Normalize to "probability of unsafe": flip the score when the winning
    # label is "safe", pass it through for "unsafe" or any other label.
    if label == "safe":
        unsafe_score = 1.0 - raw_score
    else:
        unsafe_score = raw_score
    return {
        "score": unsafe_score,
        "blocked": unsafe_score >= JAILBREAK_THRESHOLD,
        "predicted_label": label,
    }
def detect_refusal_language(text: str) -> str:
    """Detect the language to use for a refusal message, falling back to
    English when the detected language is not supported by Gemma."""
    prediction = _refusal_language_pipe(text, truncation=True, max_length=512)[0]
    lang = str(prediction.get("label", "")).upper().strip()
    return lang if lang in SUPPORTED_GEMMA_LANGS else "EN"
def detect_preferred_language(text: str) -> str:
    """Return the raw detected language code for *text*, defaulting to
    English when the classifier yields an empty label."""
    prediction = _refusal_language_pipe(text, truncation=True, max_length=512)[0]
    return str(prediction.get("label", "")).upper().strip() or "EN"
def _sanitize_display_text(text: str, system_prompt: str | None = None) -> str:
    """Return display-safe text with tool-call markup removed.

    ``system_prompt`` is accepted for interface compatibility but is not
    used by the current implementation.
    """
    # _strip_tool_call_markup already returns stripped text (or "" when the
    # result is empty), so no further normalization is needed.
    return _strip_tool_call_markup(text)
# These imports are needed for generate_response and generate_response_stream
# They are imported here to avoid circular dependencies with demo.py
from bob_resources import (
assistant_capabilities,
call,
validate,
clarify_intent,
store_policy,
store_information,
store_app_website,
food_safety_endpoint,
legal_endpoint,
emergency_crisis,
apply_discount,
loyalty_program,
competitor_mentions,
take_order,
)
def generate_response(messages: list, system_prompt: str) -> str:
    """Generate a complete assistant reply for the given chat history.

    Args:
        messages: Chat turns as role/content dicts, newest last.
        system_prompt: Text injected as the leading system turn.

    Returns:
        The newly generated text with special tokens stripped and
        surrounding whitespace trimmed.
    """
    # Proper local import replaces the original __import__("torch") hack
    # while keeping the import lazy, as before.
    import torch

    full = [{"role": "system", "content": system_prompt}] + messages
    # Render the chat template with the full tool schema so the model can
    # emit tool calls; tensors are moved to the model's device.
    inputs = _processor.apply_chat_template(
        full,
        tools=[assistant_capabilities, call, validate, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(_model.device)
    with torch.no_grad():
        out = _model.generate(  # pyright: ignore[reportAttributeAccessIssue]
            **inputs,
            max_new_tokens=400,
            temperature=0.7,
            do_sample=True,
            pad_token_id=_processor.tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens and decode only what the model produced.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return _processor.decode(new_tokens, skip_special_tokens=True).strip()
def generate_response_stream(messages: list, system_prompt: str):
    """Stream an assistant reply as it is generated.

    Yields the *cumulative* generated text after each new chunk (not deltas),
    so the final yielded value is the complete response. Special tokens are
    kept (skip_special_tokens=False) so downstream markup stripping can see
    tool-call markers.

    Args:
        messages: Chat turns as role/content dicts, newest last.
        system_prompt: Text injected as the leading system turn.
    """
    full = [{"role": "system", "content": system_prompt}] + messages
    # Same template + tool schema as generate_response; tensors go to the
    # model's device.
    inputs = _processor.apply_chat_template(
        full,
        tools=[assistant_capabilities, call, validate, clarify_intent, store_policy,
               store_information, store_app_website, food_safety_endpoint, legal_endpoint,
               emergency_crisis, apply_discount, loyalty_program, competitor_mentions, take_order],
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(_model.device)
    from transformers import TextIteratorStreamer
    streamer = TextIteratorStreamer(_processor.tokenizer, skip_prompt=True, skip_special_tokens=False)
    # generate() blocks, so it runs on a daemon thread while this generator
    # drains the streamer on the caller's thread.
    thread = threading.Thread(
        target=_model.generate,  # pyright: ignore[reportAttributeAccessIssue]
        kwargs={
            **inputs,
            "max_new_tokens": 400,
            "temperature": 0.7,
            "do_sample": True,
            "pad_token_id": _processor.tokenizer.eos_token_id,
            "streamer": streamer,
        },
        daemon=True,
    )
    thread.start()
    generated = ""
    # The streamer iterator ends when generation finishes.
    for chunk in streamer:
        generated += chunk
        yield generated
    thread.join()