wchen22's picture
Add exact tokenizer accounting to compression API
1212e7d verified
from __future__ import annotations

import copy
import gzip
import hashlib
import json
import math
import os
import re
import threading
import time
import uuid
import zlib
from functools import lru_cache
from pathlib import Path
from typing import Any

from fastapi import FastAPI, HTTPException, Request, Response
# --- Classifier model / artifact configuration ------------------------------
CLASSIFIER_MODEL = "microsoft/deberta-v3-small"
# Optional directory holding classifier_manifest.json (and optionally an ONNX
# model). When unset, no trained keep/drop head is used.
CLASSIFIER_ARTIFACT_DIR = os.environ.get("TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR")

# --- Versions stamped into responses and receipts ---------------------------
API_SCHEMA_VERSION = "0.1.0"
RULES_VERSION = "hf-space-rules-v0.1.0"

# --- HTTP header names and transport encodings ------------------------------
REQUEST_ID_HEADER = "X-Request-ID"
IDEMPOTENCY_KEY_HEADER = "Idempotency-Key"
CONTENT_ENCODING_HEADER = "Content-Encoding"
TOUCHDOWN_CONTENT_ENCODING_HEADER = "X-Touchdown-Content-Encoding"
ACCEPT_ENCODING_HEADER = "Accept-Encoding"
GZIP_ENCODING = "gzip"
GZIP_MAGIC = b"\x1f\x8b"  # leading two bytes of every gzip stream

# --- Idempotency replay cache -----------------------------------------------
DEFAULT_IDEMPOTENCY_TTL_SECONDS = 86_400  # one day
IDEMPOTENCY_TTL_ENV = "TOUCHDOWN_IDEMPOTENCY_TTL_SECONDS"

DEFAULT_PROTECTED_MESSAGE_ROLES = ("system", "developer")

# Case-insensitive boilerplate patterns. _compress_text drops a matching line
# as "low_signal_phrase" once aggressiveness is at least 0.20.
LOW_SIGNAL_PATTERNS = [
    re.compile(source, re.IGNORECASE)
    for source in (
        r"\blet me know if you (need|want|have)\b",
        r"\bi hope this helps\b",
        r"\bas an ai language model\b",
        r"\bof course[,!]?\s*$",
        r"\bcertainly[,!]?\s*$",
        r"\bsure[,!]?\s*$",
    )
]
# The FastAPI application; the HTTP middleware below is registered on it.
app = FastAPI(title="Touchdown Compression Classifier", version="0.1.0")


@app.middleware("http")
async def transport_middleware(request: Request, call_next):
    """Decode request bodies, tag requests with an id, and gzip responses.

    Steps:
    * Reuse the caller's ``X-Request-ID`` or mint a ``tdreq_<hex>`` id.
    * Decode the body per ``Content-Encoding`` (falling back to
      ``X-Touchdown-Content-Encoding``). Unsupported encodings get a 415 JSON
      error; other decode failures (ValueError) get a 400 JSON error.
    * Hand the decoded body to the downstream app through a fresh ``Request``
      whose headers drop both encoding headers and carry an updated
      ``Content-Length``.
    * Echo the request id on the response and gzip it when
      ``_gzip_response_if_needed`` decides to.
    """
    request_id = request.headers.get(REQUEST_ID_HEADER) or f"tdreq_{uuid.uuid4().hex}"
    try:
        body, request_was_gzip = _decode_request_body(
            await request.body(),
            request.headers.get(CONTENT_ENCODING_HEADER)
            or request.headers.get(TOUCHDOWN_CONTENT_ENCODING_HEADER),
        )
    # UnsupportedContentEncoding subclasses ValueError, so it must be caught first.
    except UnsupportedContentEncoding as exc:
        response = Response(
            content=json.dumps({
                "status": "error",
                "error": str(exc),
                "request_id": request_id,
            }),
            status_code=415,
            media_type="application/json",
            headers={REQUEST_ID_HEADER: request_id},
        )
        return await _gzip_response_if_needed(
            request,
            response,
            request_was_gzip=False,
        )
    except ValueError as exc:
        response = Response(
            content=json.dumps({
                "status": "error",
                "error": str(exc),
                "request_id": request_id,
            }),
            status_code=400,
            media_type="application/json",
            headers={REQUEST_ID_HEADER: request_id},
        )
        return await _gzip_response_if_needed(
            request,
            response,
            request_was_gzip=False,
        )
    body_sent = False

    # ASGI receive callable that replays the already-decoded body once as a
    # single final chunk; any later call returns an empty final chunk.
    async def receive() -> dict[str, Any]:
        nonlocal body_sent
        if body_sent:
            return {"type": "http.request", "body": b"", "more_body": False}
        body_sent = True
        return {"type": "http.request", "body": body, "more_body": False}

    # Shallow-copy the scope, then rewrite its raw (bytes) header list: the
    # encoding headers are removed because the body is now plain, and an
    # existing content-length is replaced with the decoded length.
    # NOTE(review): if the client sent no content-length header, none is added.
    scope = dict(request.scope)
    scope["headers"] = [
        (
            key,
            str(len(body)).encode("latin-1")
            if key.lower() == b"content-length" else value,
        )
        for key, value in request.scope.get("headers", [])
        if key.lower() not in (b"content-encoding", b"x-touchdown-content-encoding")
    ]
    request = Request(scope, receive)
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers[REQUEST_ID_HEADER] = request_id
    return await _gzip_response_if_needed(
        request,
        response,
        request_was_gzip=request_was_gzip,
    )
class UnsupportedContentEncoding(ValueError):
    """A request body declares a Content-Encoding this service cannot decode.

    It subclasses ``ValueError``, so generic validation handlers still catch
    it. The transport middleware catches it first and answers HTTP 415.
    """
_IDEMPOTENCY_LOCK = threading.Lock()
_IDEMPOTENCY_CACHE: dict[tuple[str, str, str], dict[str, Any]] = {}
def _content_encoding(value: str | None) -> str:
return (value or "identity").split(";", 1)[0].strip().lower() or "identity"
def _decode_request_body(raw: bytes, content_encoding: str | None) -> tuple[bytes, bool]:
    """Decode an HTTP request body according to its declared content encoding.

    Args:
        raw: Body bytes as received.
        content_encoding: Raw Content-Encoding header value, or None.

    Returns:
        ``(body, was_gzip)``. A body declared as identity (or undeclared) that
        still starts with the gzip magic bytes is sniffed and decompressed.

    Raises:
        UnsupportedContentEncoding: for any coding other than identity/gzip.
        ValueError: when a gzip body is malformed, truncated or corrupt.
    """
    encoding = _content_encoding(content_encoding)
    if encoding == "identity" and not raw.startswith(GZIP_MAGIC):
        return raw, False
    if encoding not in ("identity", GZIP_ENCODING):
        raise UnsupportedContentEncoding(
            f"unsupported content encoding: {content_encoding}"
        )
    # NOTE(review): the decompressed size is not bounded, so a small gzip bomb
    # can expand to a very large body. Consider a size cap.
    try:
        return gzip.decompress(raw), True
    except (OSError, EOFError, zlib.error) as exc:
        # gzip.decompress reports a bad header as BadGzipFile (an OSError), a
        # truncated stream as EOFError, and corrupt deflate data as
        # zlib.error. All three are client errors, not server faults.
        raise ValueError("invalid gzip request body") from exc
def _accepts_gzip(accept_encoding: str | None) -> bool:
for part in (accept_encoding or "").split(","):
normalized = part.replace(" ", "").lower()
token = normalized.split(";", 1)[0]
if token in (GZIP_ENCODING, "*") and "q=0" not in normalized:
return True
return False
def _should_gzip_response(
*,
accept_encoding: str | None,
request_was_gzip: bool,
) -> bool:
return request_was_gzip or _accepts_gzip(accept_encoding)
def _idempotency_ttl_seconds() -> int:
    """Return the replay-cache TTL in seconds, read from the environment.

    An unset or non-integer value falls back to
    DEFAULT_IDEMPOTENCY_TTL_SECONDS. Negative values are clamped to 0, and a
    TTL of 0 makes _store_idempotency_body skip caching.
    """
    configured = os.environ.get(IDEMPOTENCY_TTL_ENV)
    if configured is not None:
        try:
            return max(0, int(configured))
        except ValueError:
            pass
    return DEFAULT_IDEMPOTENCY_TTL_SECONDS
def _idempotency_key_from_payload(payload: dict[str, Any]) -> str | None:
value = payload.get("idempotency_key")
return value if isinstance(value, str) and value else None
def _fingerprint_value(value: Any) -> Any:
if isinstance(value, dict):
return {
key: _fingerprint_value(item)
for key, item in sorted(value.items())
if key not in {"request_id", "idempotency_key"}
}
if isinstance(value, list):
return [_fingerprint_value(item) for item in value]
return value
def _idempotency_fingerprint(route: str, payload: dict[str, Any]) -> str:
    """Return a SHA-256 hex digest identifying *payload* as sent to *route*.

    The payload is canonicalised by _fingerprint_value, which drops
    correlation-only fields, then serialised as compact, key-sorted,
    non-ASCII-preserving JSON before hashing.
    """
    canonical = {"payload": _fingerprint_value(payload), "route": route}
    serialized = json.dumps(
        canonical,
        separators=(",", ":"),
        sort_keys=True,
        ensure_ascii=False,
    )
    digest = hashlib.sha256(serialized.encode("utf-8"))
    return digest.hexdigest()
def _walk_receipts(value: Any) -> list[dict[str, Any]]:
receipts: list[dict[str, Any]] = []
if isinstance(value, dict):
receipt = value.get("receipt")
if isinstance(receipt, dict):
receipts.append(receipt)
for item in value.values():
receipts.extend(_walk_receipts(item))
elif isinstance(value, list):
for item in value:
receipts.extend(_walk_receipts(item))
return receipts
def _mark_idempotency(
    body: dict[str, Any],
    *,
    key: str,
    fingerprint: str,
    replayed: bool,
    original_request_id: str | None,
) -> dict[str, Any]:
    """Return a deep copy of *body* annotated with idempotency metadata.

    The same annotations go on the top-level body and on every nested receipt
    whose ``idempotency_key`` is missing/None or equals *key*. Receipts carrying
    a different key are left alone. ``idempotency_original_request_id`` is
    only added when *original_request_id* is truthy.
    """
    ttl_seconds = _idempotency_ttl_seconds()
    annotations: dict[str, Any] = {
        "idempotency_key": key,
        "idempotency_mode": "in_memory_response_replay",
        "idempotency_replayed": replayed,
        "idempotency_fingerprint": fingerprint,
        "idempotency_ttl_seconds": ttl_seconds,
    }
    if original_request_id:
        annotations["idempotency_original_request_id"] = original_request_id
    annotated = copy.deepcopy(body)
    annotated.update(annotations)
    for receipt in _walk_receipts(annotated):
        if receipt.get("idempotency_key") not in (None, key):
            continue
        receipt.update(annotations)
    return annotated
def _cached_idempotency_body(
    *,
    route: str,
    key: str,
    fingerprint: str,
) -> dict[str, Any] | None:
    """Look up a cached response for (*route*, *key*) and prepare it for replay.

    Expired entries are purged first. If an entry exists for the same route
    and key:
    * a matching *fingerprint* returns its body re-marked as a replay;
    * a different fingerprint raises HTTP 409.
    Returns None when nothing is cached. All cache access happens under
    _IDEMPOTENCY_LOCK.
    """
    now = time.monotonic()
    with _IDEMPOTENCY_LOCK:
        stale_keys = [
            cache_key
            for cache_key, record in _IDEMPOTENCY_CACHE.items()
            if float(record["expires_at"]) <= now
        ]
        for stale_key in stale_keys:
            _IDEMPOTENCY_CACHE.pop(stale_key, None)
        existing = next(
            (
                record
                for cache_key, record in _IDEMPOTENCY_CACHE.items()
                if cache_key[:2] == (route, key)
            ),
            None,
        )
        if existing is None:
            return None
        if existing["fingerprint"] != fingerprint:
            raise HTTPException(status_code=409, detail={
                "schema_version": API_SCHEMA_VERSION,
                "status": "error",
                "error": "idempotency_key_reused_with_different_payload",
                "idempotency_key": key,
                "idempotency_fingerprint": fingerprint,
                "idempotency_existing_fingerprint": existing["fingerprint"],
            })
        return _mark_idempotency(
            existing["body"],
            key=key,
            fingerprint=fingerprint,
            replayed=True,
            original_request_id=existing.get("request_id"),
        )
def _store_idempotency_body(
    *,
    route: str,
    key: str,
    fingerprint: str,
    body: dict[str, Any],
    request_id: str | None,
) -> dict[str, Any]:
    """Mark *body* as a first (non-replayed) response and cache it.

    The marked body is returned either way. It is only stored when the
    configured TTL is positive. The cache entry is keyed by
    (route, key, fingerprint) and expires TTL seconds after insertion.
    """
    first_response = _mark_idempotency(
        body,
        key=key,
        fingerprint=fingerprint,
        replayed=False,
        original_request_id=request_id,
    )
    ttl_seconds = _idempotency_ttl_seconds()
    if ttl_seconds <= 0:
        return first_response
    with _IDEMPOTENCY_LOCK:
        _IDEMPOTENCY_CACHE[(route, key, fingerprint)] = dict(
            fingerprint=fingerprint,
            body=first_response,
            request_id=request_id,
            expires_at=time.monotonic() + ttl_seconds,
        )
    return first_response
def _handle_compress_with_idempotency(payload: dict[str, Any]) -> dict[str, Any]:
    """Run /v1/compress with in-memory idempotent replay when a key is supplied.

    Without an ``idempotency_key`` the request is simply dispatched. With one,
    a cached response for the same key and payload fingerprint is replayed.
    Otherwise the request is dispatched and its marked response is cached.
    """
    key = _idempotency_key_from_payload(payload)
    if not key:
        return _dispatch_compress(payload)
    route = "/v1/compress"
    fingerprint = _idempotency_fingerprint(route, payload)
    replay = _cached_idempotency_body(route=route, key=key, fingerprint=fingerprint)
    if replay is not None:
        return replay
    fresh_body = _dispatch_compress(payload)
    raw_request_id = payload.get("request_id")
    return _store_idempotency_body(
        route=route,
        key=key,
        fingerprint=fingerprint,
        body=fresh_body,
        request_id=raw_request_id if isinstance(raw_request_id, str) else None,
    )
async def _gzip_response_if_needed(
    request: Request,
    response: Response,
    *,
    request_was_gzip: bool,
) -> Response:
    """Return *response* gzip-compressed when the client should receive gzip.

    _should_gzip_response makes the decision. The response is returned
    untouched when it already declares a Content-Encoding, or when its status
    must not carry a body (1xx, 204, 304).

    Otherwise the body is buffered (draining a streamed ``body_iterator`` if
    present) and compressed. It is re-wrapped in a new ``Response`` that:
    * keeps every original header, including repeated ones such as
      ``Set-Cookie``, except the stale Content-Length/Content-Encoding;
    * lists ``Accept-Encoding`` in ``Vary``.
    """
    should_gzip = _should_gzip_response(
        accept_encoding=request.headers.get(ACCEPT_ENCODING_HEADER),
        request_was_gzip=request_was_gzip,
    )
    if not should_gzip or response.headers.get(CONTENT_ENCODING_HEADER):
        return response
    # These statuses must not have a body. Attaching a gzip stream (and a
    # non-zero Content-Length) to them would violate HTTP.
    if response.status_code < 200 or response.status_code in (204, 304):
        return response
    if hasattr(response, "body_iterator"):
        charset = getattr(response, "charset", "utf-8")
        parts: list[bytes] = []
        async for chunk in response.body_iterator:
            parts.append(chunk.encode(charset) if isinstance(chunk, str) else chunk)
        body = b"".join(parts)  # join once instead of quadratic +=
    else:
        body = getattr(response, "body", b"")
    compressed = gzip.compress(body)
    compressed_response = Response(
        content=compressed,
        status_code=response.status_code,
        background=getattr(response, "background", None),
    )
    vary_tokens: list[str] = []
    for key, value in response.headers.items():
        lowered = key.lower()
        if lowered in ("content-length", "content-encoding"):
            continue
        if lowered == "vary":
            vary_tokens.extend(
                token.strip() for token in value.split(",") if token.strip()
            )
            continue
        # append() rather than item assignment: repeated headers such as
        # Set-Cookie must all survive, which a plain dict copy would collapse.
        compressed_response.headers.append(key, value)
    if "*" not in vary_tokens and ACCEPT_ENCODING_HEADER.lower() not in (
        token.lower() for token in vary_tokens
    ):
        vary_tokens.append(ACCEPT_ENCODING_HEADER)
    compressed_response.headers["Vary"] = ", ".join(vary_tokens)
    compressed_response.headers[CONTENT_ENCODING_HEADER] = GZIP_ENCODING
    compressed_response.headers["Content-Length"] = str(len(compressed))
    return compressed_response
@lru_cache(maxsize=1)
def _get_tokenizer():
    """Load the CLASSIFIER_MODEL tokenizer once and memoise it.

    ``transformers`` is imported lazily, at the first call.
    NOTE(review): no caller is visible in this chunk; confirm it is still used.
    """
    import transformers
    return transformers.AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
@lru_cache(maxsize=8)
def _get_count_tokenizer(model_name: str):
    """Load a tokenizer used only for token counting (memoised LRU, up to 8 models).

    NOTE(review): *model_name* comes straight from the request's
    ``tokenizer_model`` field (via _count_tokens). Callers can therefore make
    the server load arbitrary tokenizers, presumably fetching them from the
    Hugging Face Hub. lru_cache does not cache exceptions, so a failing name
    is retried on every request. Consider an allow-list.
    """
    import transformers
    return transformers.AutoTokenizer.from_pretrained(model_name)
@lru_cache(maxsize=1)
def _classifier_manifest() -> dict[str, Any] | None:
    """Load ``classifier_manifest.json`` from CLASSIFIER_ARTIFACT_DIR (memoised).

    Returns:
        * None when no artifact directory is configured;
        * the parsed manifest when it is a JSON object;
        * otherwise ``{"status": "manifest_error", "error": ...}`` describing
          a missing file, a read/parse failure, or a non-object document.

    Errors are returned rather than raised, so lru_cache memoises them too
    for the life of the process.
    """
    if not CLASSIFIER_ARTIFACT_DIR:
        return None
    manifest_path = Path(CLASSIFIER_ARTIFACT_DIR) / "classifier_manifest.json"
    if not manifest_path.exists():
        return {
            "status": "manifest_error",
            "error": f"missing classifier manifest: {manifest_path}",
        }
    try:
        loaded = json.loads(manifest_path.read_text(encoding="utf-8"))
    except Exception as exc:
        return {"status": "manifest_error", "error": str(exc)}
    if isinstance(loaded, dict):
        return loaded
    return {"status": "manifest_error", "error": "manifest must be an object"}
@lru_cache(maxsize=1)
def _get_artifact_tokenizer():
    """Load the tokenizer saved in CLASSIFIER_ARTIFACT_DIR once and memoise it.

    NOTE(review): if CLASSIFIER_ARTIFACT_DIR were unset, ``str(None)`` would
    load from the path ``"None"``. The visible caller (_onnx_labels, reached
    via _manifest_drop_labels) only runs after a manifest has been loaded
    from that directory.
    """
    import transformers
    artifact_dir = str(CLASSIFIER_ARTIFACT_DIR)
    return transformers.AutoTokenizer.from_pretrained(artifact_dir)
@lru_cache(maxsize=1)
def _get_onnx_session(model_path: str):
    """Create a CPU-only ONNX Runtime session for *model_path*.

    Only the most recently requested path stays memoised (``maxsize=1``).
    """
    from onnxruntime import InferenceSession
    return InferenceSession(model_path, providers=["CPUExecutionProvider"])
def _find_spans(text: str, needle: str) -> list[tuple[int, int]]:
spans = []
cursor = 0
while needle:
index = text.find(needle, cursor)
if index < 0:
return spans
spans.append((index, index + len(needle)))
cursor = index + len(needle)
return spans
def _regex_spans(text: str, pattern: str, flags: int) -> list[tuple[int, int]]:
return [(match.start(), match.end()) for match in re.finditer(pattern, text, flags)]
def _json_line_spans(text: str) -> list[tuple[int, int]]:
spans = []
cursor = 0
for line in text.splitlines(keepends=True):
stripped = line.strip()
if stripped and stripped[0] in "[{":
try:
import json
json.loads(stripped)
except Exception:
pass
else:
offset = line.index(stripped[0])
spans.append((cursor + offset, cursor + offset + len(stripped)))
cursor += len(line)
return spans
def _merge(spans: list[tuple[int, int]]) -> list[tuple[int, int]]:
if not spans:
return []
ordered = sorted(spans)
merged = [ordered[0]]
for start, end in ordered[1:]:
prev_start, prev_end = merged[-1]
if start <= prev_end:
merged[-1] = (prev_start, max(prev_end, end))
else:
merged.append((start, end))
return merged
def _overlaps(spans: list[tuple[int, int]], start: int, end: int) -> bool:
return any(start < span_end and end > span_start for span_start, span_end in spans)
def _manifest_drop_labels(text: str) -> tuple[str, list[dict[str, Any]], str | None]:
    """Produce classifier DROP labels for *text* from the configured artifact.

    Returns ``(status, labels, error)``:
    * no manifest (unset dir or empty manifest): placeholder status, no labels;
    * manifest load error: ``"manifest_error"`` plus its message;
    * ONNX model file present: per-token labels from _onnx_labels. Any
      failure is reported as ``"local_artifact_error"``;
    * otherwise: every exact occurrence of each manifest ``drop_phrases``
      entry (a string, or an object with ``text`` and optional ``score``).
    """
    manifest = _classifier_manifest()
    if not manifest:
        return "tokenizer_loaded_no_trained_keep_drop_head", [], None
    if manifest.get("status") == "manifest_error":
        return "manifest_error", [], str(manifest.get("error"))
    onnx_path = Path(CLASSIFIER_ARTIFACT_DIR or "") / str(
        manifest.get("onnx_model", "model.onnx")
    )
    if onnx_path.exists():
        try:
            token_labels = _onnx_labels(text, manifest, onnx_path)
        except Exception as exc:
            return "local_artifact_error", [], str(exc)
        return str(manifest.get("status") or "trained_keep_drop_head"), token_labels, None
    phrase_labels: list[dict[str, Any]] = []
    for classifier_index, entry in enumerate(manifest.get("drop_phrases") or []):
        if isinstance(entry, str):
            phrase, score = entry, 1.0
        elif isinstance(entry, dict):
            phrase = str(entry.get("text") or "")
            score = float(entry.get("score", 1.0))
        else:
            continue
        phrase_labels.extend(
            {
                "token": phrase,
                "label": "DROP",
                "score": score,
                "start": start,
                "end": end,
                "source": "manifest_drop_phrase",
                "classifier_index": classifier_index,
            }
            for start, end in _find_spans(text, phrase)
        )
    return str(manifest.get("status") or "artifact_manifest_loaded"), phrase_labels, None
def _softmax(values: list[float]) -> list[float]:
peak = max(values)
exps = [math.exp(value - peak) for value in values]
total = sum(exps) or 1.0
return [value / total for value in exps]
def _onnx_labels(
    text: str,
    manifest: dict[str, Any],
    model_path: Path,
) -> list[dict[str, Any]]:
    """Label character spans of *text* as KEEP/DROP with the ONNX token classifier.

    Steps:
    * Tokenize *text* with the artifact tokenizer into one or more padded
      windows (``max_length``/``stride`` from *manifest*, overflow enabled).
    * Run the ONNX session on the tokenizer outputs it declares as inputs.
    * Softmax each token's logits into keep/drop probabilities.

    A token is labelled DROP when its drop probability is at least
    ``drop_score_threshold`` (manifest, default 0.5). When the same character
    span is produced by several windows, the entry with the highest
    ``drop_score`` wins. Tokens whose offsets are empty (``end <= start``) are
    skipped.

    Args:
        text: Text to classify.
        manifest: Parsed classifier manifest. Reads ``max_length``, ``stride``,
            ``id2label`` and ``drop_score_threshold``.
        model_path: Path to the ONNX model file.

    Returns:
        One label dict per distinct ``(start, end)`` character span, sorted by
        span.
    """
    import numpy as np
    tokenizer = _get_artifact_tokenizer()
    # NOTE(review): return_offsets_mapping needs a fast (Rust-backed)
    # tokenizer — confirm the artifact ships one.
    encoded = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        return_tensors="np",
        truncation=True,
        max_length=int(manifest.get("max_length", 512)),
        stride=int(manifest.get("stride", 0)),
        return_overflowing_tokens=True,
        padding=True,
    )
    # Offsets are needed for span mapping, not as a model input.
    offsets = encoded.pop("offset_mapping")
    input_ids = encoded["input_ids"]
    session = _get_onnx_session(str(model_path))
    # Feed only the tensors the exported graph declares as inputs.
    input_names = {item.name for item in session.get_inputs()}
    inputs = {key: value for key, value in encoded.items() if key in input_names}
    # Assumes the first output is per-token logits shaped
    # (num_windows, seq_len, num_labels) — TODO confirm against the exported model.
    logits = session.run(None, inputs)[0]
    # Label map defaults to {0: KEEP, 1: DROP}; keys are normalised to strings.
    id2label = {
        str(key): value
        for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items()
    }
    label2id = {str(value).upper(): int(key) for key, value in id2label.items()}
    keep_id = label2id.get("KEEP", 0)
    drop_id = label2id.get("DROP", 1)
    drop_score_threshold = float(manifest.get("drop_score_threshold", 0.5))
    best_by_span: dict[tuple[int, int], dict[str, Any]] = {}
    for chunk_index, chunk_logits in enumerate(logits):
        tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index])
        for token_index, token_logits in enumerate(chunk_logits):
            probs = _softmax(np.asarray(token_logits, dtype=float).tolist())
            # A label id outside the logits width scores 0.0 rather than raising.
            keep_score = float(probs[keep_id]) if keep_id < len(probs) else 0.0
            drop_score = float(probs[drop_id]) if drop_id < len(probs) else 0.0
            label_id = drop_id if drop_score >= drop_score_threshold else keep_id
            start, end = offsets[chunk_index][token_index]
            start = int(start)
            end = int(end)
            # Zero-width offsets (presumably padding positions) cover no text.
            if end <= start:
                continue
            key = (start, end)
            item = {
                "token": tokens[token_index],
                "label": id2label.get(str(label_id), "KEEP"),
                "score": round(drop_score if label_id == drop_id else keep_score, 6),
                "keep_score": round(keep_score, 6),
                "drop_score": round(drop_score, 6),
                "drop_score_threshold": drop_score_threshold,
                "start": start,
                "end": end,
                "source": "onnx_token_classifier",
                "chunk_index": chunk_index,
            }
            # Windows may repeat a span; keep the most drop-leaning reading.
            if drop_score > float(best_by_span.get(key, {}).get("drop_score", -1.0)):
                best_by_span[key] = item
    return [best_by_span[key] for key in sorted(best_by_span)]
def _safe_classifier_drop_ranges(
    labels: list[dict[str, Any]],
    protected: list[tuple[int, int]],
    *,
    text_length: int,
    min_score: float,
) -> tuple[list[tuple[int, int]], int, int, dict[str, int]]:
    """Turn classifier labels into character ranges that are safe to delete.

    A label is a drop candidate when it is labelled ``DROP`` (via ``label`` or
    ``entity``) or its ``drop_score`` is at least *min_score*. A candidate is
    then blocked when:
    * its offsets are missing or malformed;
    * its span lies outside ``[0, text_length]`` or is empty;
    * its score is below *min_score*;
    * it overlaps a *protected* span.

    NaN scores are treated as missing, and thresholds are tested as
    ``not score >= min_score``. A NaN score or a NaN *min_score* therefore
    fails closed instead of slipping past every ``<`` comparison.

    Returns:
        ``(merged_ranges, drop_candidates, len(merged_ranges), blocked_by_reason)``.
    """
    blocked_by_reason = {
        "malformed_or_missing_offsets": 0,
        "out_of_bounds": 0,
        "below_min_score": 0,
        "protected_span_overlap": 0,
    }
    ranges: list[tuple[int, int]] = []
    drop_labels = 0
    for item in labels:
        raw_drop_score = item.get("drop_score")
        try:
            drop_score = float(raw_drop_score) if raw_drop_score is not None else None
        except (TypeError, ValueError):
            drop_score = None
        if drop_score is not None and math.isnan(drop_score):
            # NaN carries no ranking information; treat it like a missing score.
            drop_score = None
        labelled_drop = (
            str(item.get("label") or item.get("entity") or "").upper() == "DROP"
        )
        if not labelled_drop and (drop_score is None or not drop_score >= min_score):
            continue
        drop_labels += 1
        try:
            start = int(item["start"])
            end = int(item["end"])
            score = drop_score if drop_score is not None else float(item.get("score", 1.0))
        except (KeyError, TypeError, ValueError, OverflowError):
            blocked_by_reason["malformed_or_missing_offsets"] += 1
            continue
        if start < 0 or end > text_length or end <= start:
            blocked_by_reason["out_of_bounds"] += 1
            continue
        if not score >= min_score:
            blocked_by_reason["below_min_score"] += 1
            continue
        if _overlaps(protected, start, end):
            blocked_by_reason["protected_span_overlap"] += 1
            continue
        ranges.append((start, end))
    merged = _merge(ranges)
    return merged, drop_labels, len(merged), blocked_by_reason
def _is_subsequence(candidate: str, original: str) -> bool:
cursor = 0
for char in candidate:
cursor = original.find(char, cursor)
if cursor < 0:
return False
cursor += 1
return True
def _sha256_text(value: str) -> str:
return hashlib.sha256(value.encode("utf-8")).hexdigest()
def _receipt_id(payload: dict[str, Any]) -> str:
encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"))
return "tdcr_" + hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:24]
def _stable_json_sha256(value: Any) -> str:
encoded = json.dumps(
value,
ensure_ascii=False,
sort_keys=True,
separators=(",", ":"),
)
return hashlib.sha256(encoded.encode("utf-8")).hexdigest()
def _aggregate_receipt_id(payload: dict[str, Any]) -> str:
encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"))
return "tdcm_" + hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:24]
def _string_set(
value: Any,
*,
default: tuple[str, ...] = (),
field_name: str,
) -> set[str]:
if value is None:
return set(default)
if (
isinstance(value, list)
and all(isinstance(item, str) and item for item in value)
):
return {item.lower() for item in value}
raise HTTPException(status_code=400, detail=f"{field_name} must be a list of strings")
def _message_role(message: dict[str, Any], index: int) -> str:
role = message.get("role")
if not isinstance(role, str) or not role:
raise HTTPException(
status_code=400,
detail=f"messages[{index}].role must be a string",
)
return role.lower()
def _message_decision(
*,
tokens_saved: int,
receipts: list[dict[str, Any]],
) -> str:
decisions = [
receipt.get("decision")
for receipt in receipts
if isinstance(receipt, dict)
]
if any(decision == "reject" for decision in decisions):
return "reject"
if any(decision == "needs_review" for decision in decisions):
return "needs_review"
if tokens_saved <= 0:
return "no_op"
return "high_confidence"
def _correlation_payload(
payload: dict[str, Any],
*,
request_id: str | None = None,
idempotency_key: str | None = None,
) -> dict[str, Any]:
out = dict(payload)
if request_id and "request_id" not in out:
out["request_id"] = request_id
if idempotency_key and "idempotency_key" not in out:
out["idempotency_key"] = idempotency_key
return out
def _tool_schema_variants(schema: Any) -> list[str]:
if isinstance(schema, str):
return [schema] if schema else []
if not isinstance(schema, (dict, list)):
raise HTTPException(
status_code=400,
detail="tool_schemas entries must be strings, objects, or arrays",
)
variants: list[str] = []
for kwargs in (
{"ensure_ascii": False, "sort_keys": False, "separators": (",", ":")},
{"ensure_ascii": False, "sort_keys": True, "separators": (",", ":")},
{"ensure_ascii": False, "sort_keys": False, "indent": 2},
{"ensure_ascii": False, "sort_keys": True, "indent": 2},
):
encoded = json.dumps(schema, **kwargs)
if encoded not in variants:
variants.append(encoded)
return variants
def _tool_schema_groups(tool_schemas: Any) -> list[list[str]]:
if tool_schemas is None:
return []
if not isinstance(tool_schemas, list):
raise HTTPException(status_code=400, detail="tool_schemas must be a list")
return [_tool_schema_variants(schema) for schema in tool_schemas]
def _schema_occurrences(text: str, group: list[str]) -> int:
return sum(len(_find_spans(text, candidate)) for candidate in group)
def _tool_schema_missing_groups(
input_text: str,
output: str,
groups: list[list[str]],
) -> list[list[str]]:
missing = []
for group in groups:
input_occurrences = _schema_occurrences(input_text, group)
output_occurrences = _schema_occurrences(output, group)
if input_occurrences > 0 and output_occurrences == 0:
missing.append(group)
return missing
def _optional_string(payload: dict[str, Any], key: str) -> str | None:
value = payload.get(key)
if value is None:
return None
if not isinstance(value, str):
raise HTTPException(status_code=400, detail=f"{key} must be a string")
return value
def _count_tokens(text: str, tokenizer_model: str | None) -> tuple[int, bool, str]:
if tokenizer_model:
try:
tokenizer = _get_count_tokenizer(tokenizer_model)
return (
len(tokenizer.encode(text, add_special_tokens=False)),
True,
"tokenizer",
)
except Exception:
pass
return max(1, round(len(text) / 4.0)), False, "chars_per_token_estimate"
def _protected_spans(
    text: str,
    protected_values: list[str],
    tool_schema_groups: list[list[str]],
    *,
    preserve_code: bool,
    preserve_json: bool,
) -> tuple[
    list[tuple[int, int]],
    list[tuple[int, int]],
    list[tuple[int, int]],
    list[tuple[int, int]],
    list[tuple[int, int]],
]:
    """Locate every span of *text* that compression must leave intact.

    Protected content comprises:
    * caller-supplied literal values;
    * occurrences of any tool-schema rendering;
    * triple-backtick code fences (if *preserve_code*);
    * ```json fences and standalone JSON lines (if *preserve_json*);
    * ``<system ...>...</system>`` blocks and ``system:`` lines (always).

    Returns:
        ``(all_protected, code, json, system, tool_schema)``. ``all_protected``
        and ``tool_schema`` are merged and sorted; the other three are raw
        match lists.
    """
    fence_flags = re.MULTILINE | re.DOTALL
    literal_spans = [
        span for value in protected_values for span in _find_spans(text, value)
    ]
    tool_schema_spans = [
        span
        for group in tool_schema_groups
        for candidate in group
        for span in _find_spans(text, candidate)
    ]
    code_spans = _regex_spans(text, r"```.*?```", fence_flags) if preserve_code else []
    json_spans: list[tuple[int, int]] = []
    if preserve_json:
        json_spans = (
            _regex_spans(text, r"```(?:json|JSON)\s*\n.*?```", fence_flags)
            + _json_line_spans(text)
        )
    system_spans = (
        _regex_spans(text, r"<system\b[^>]*>.*?</system>", fence_flags)
        + _regex_spans(text, r"^system\s*:\s*.*$", re.MULTILINE | re.IGNORECASE)
    )
    every_span = literal_spans + tool_schema_spans + code_spans + json_spans + system_spans
    return (
        _merge(every_span),
        code_spans,
        json_spans,
        system_spans,
        _merge(tool_schema_spans),
    )
def _compress_text(payload: dict[str, Any]) -> dict[str, Any]:
    """Compress ``payload["input"]`` by deleting low-value spans and build a receipt.

    The compressor is deletion-only. It removes:
    * redundant blank lines, duplicate long lines and low-signal boilerplate
      lines, as allowed by ``aggressiveness``;
    * classifier DROP spans that pass _safe_classifier_drop_ranges.

    Lines that overlap protected content are never removed. Protected content
    is the caller's ``protected_spans``, tool schemas, code/JSON fences when
    enabled, and system-prompt markup.

    Args:
        payload: Request object with ``input`` (str) and optional
            ``compression_settings``, ``protected_spans``, ``tool_schemas``
            (or ``tools``), ``tokenizer_model``, ``request_id`` and
            ``idempotency_key``.

    Returns:
        The response body: ``output``, token accounting, and a ``receipt``
        with preservation checks and content hashes.

    Raises:
        HTTPException: 400 for malformed request fields.
    """
    text = payload.get("input")
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="input must be a string")
    settings = payload.get("compression_settings") or {}
    if not isinstance(settings, dict):
        raise HTTPException(status_code=400, detail="compression_settings must be an object")
    request_id = payload.get("request_id")
    if request_id is not None and not isinstance(request_id, str):
        raise HTTPException(status_code=400, detail="request_id must be a string")
    idempotency_key = payload.get("idempotency_key")
    if idempotency_key is not None and not isinstance(idempotency_key, str):
        raise HTTPException(status_code=400, detail="idempotency_key must be a string")
    tokenizer_model = _optional_string(payload, "tokenizer_model")
    protected_values = payload.get("protected_spans") or []
    if not isinstance(protected_values, list) or not all(
        isinstance(value, str) for value in protected_values
    ):
        raise HTTPException(status_code=400, detail="protected_spans must be strings")
    tool_schema_groups = _tool_schema_groups(
        payload.get("tool_schemas", payload.get("tools"))
    )
    try:
        aggressiveness = float(settings.get("aggressiveness", 0.35))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="aggressiveness must be numeric") from exc
    if not 0.0 <= aggressiveness <= 1.0:
        raise HTTPException(status_code=400, detail="aggressiveness must be 0..1")
    # Validate up front: a non-numeric value should be a 400, not an unhandled
    # ValueError/TypeError, and NaN would otherwise defeat every threshold
    # comparison in the classifier safety filter.
    try:
        classifier_min_drop_score = float(settings.get("classifier_min_drop_score", 0.5))
    except (TypeError, ValueError, OverflowError) as exc:
        raise HTTPException(
            status_code=400,
            detail="classifier_min_drop_score must be numeric",
        ) from exc
    if math.isnan(classifier_min_drop_score):
        raise HTTPException(
            status_code=400,
            detail="classifier_min_drop_score must be numeric",
        )
    started = time.perf_counter()
    preserve_code = bool(settings.get("preserve_code", True))
    preserve_json = bool(settings.get("preserve_json", True))
    protected, code_spans, json_spans, system_spans, tool_schema_spans = _protected_spans(
        text,
        protected_values,
        tool_schema_groups,
        preserve_code=preserve_code,
        preserve_json=preserve_json,
    )
    classifier_status, classifier_labels, classifier_error = _manifest_drop_labels(text)
    classifier_ranges, classifier_drop_labels, classifier_applied, classifier_blocked_by_reason = (
        _safe_classifier_drop_ranges(
            classifier_labels,
            protected,
            text_length=len(text),
            min_score=classifier_min_drop_score,
        )
    )
    # Line-level rules. `seen` holds normalised (whitespace-collapsed,
    # lowercased) lines already kept, so later long duplicates can be dropped.
    seen: set[str] = set()
    drops: list[tuple[int, int, str, str]] = []
    cursor = 0
    previous_blank_kept = False
    for line in text.splitlines(keepends=True):
        start = cursor
        end = cursor + len(line)
        cursor = end
        stripped = line.strip()
        normalized = re.sub(r"\s+", " ", stripped).lower()
        if _overlaps(protected, start, end):
            # Protected lines are always kept but still count as "seen".
            if normalized:
                seen.add(normalized)
            previous_blank_kept = False
            continue
        reason = None
        if not stripped and previous_blank_kept:
            reason = "extra_blank"
        elif len(normalized) > 30 and normalized in seen and aggressiveness >= 0.10:
            reason = "duplicate_line"
        elif aggressiveness >= 0.20 and any(
            pattern.search(stripped) for pattern in LOW_SIGNAL_PATTERNS
        ):
            reason = "low_signal_phrase"
        if reason:
            drops.append((start, end, reason, stripped[:120]))
            continue
        previous_blank_kept = not stripped
        if normalized:
            seen.add(normalized)
    # Rebuild the output from the text between the merged drop ranges.
    chunks = []
    drop_ranges = _merge([(start, end) for start, end, _, _ in drops] + classifier_ranges)
    cursor = 0
    for start, end in drop_ranges:
        chunks.append(text[cursor:start])
        cursor = end
    chunks.append(text[cursor:])
    output = "".join(chunks)
    before, before_exact, token_method = _count_tokens(text, tokenizer_model)
    after, after_exact, after_method = _count_tokens(output, tokenizer_model)
    token_count_exact = before_exact and after_exact
    if after_method != token_method:
        token_method = "chars_per_token_estimate"
    saved = max(0, before - after)
    # An exact tokenizer can report 0 input tokens (e.g. for empty input);
    # guard the division instead of failing with ZeroDivisionError.
    compression_percentage = round(100.0 * saved / before, 1) if before > 0 else 0.0
    missing = [value for value in protected_values if value and value not in output]
    code_preserved = all(text[start:end] in output for start, end in code_spans)
    json_preserved = all(text[start:end] in output for start, end in json_spans)
    system_preserved = all(text[start:end] in output for start, end in system_spans)
    missing_tool_schemas = _tool_schema_missing_groups(
        text,
        output,
        tool_schema_groups,
    )
    tool_schemas_preserved = not missing_tool_schemas
    decision = "reject" if (
        missing
        or not code_preserved
        or not json_preserved
        or not system_preserved
        or not tool_schemas_preserved
    ) else (
        "high_confidence" if saved > 0 and aggressiveness <= 0.65 else "no_op"
    )
    dropped_segments = [
        {"reason": reason, "preview": preview, "start": start, "end": end}
        for start, end, reason, preview in drops[:20]
    ]
    classifier_blocked = sum(classifier_blocked_by_reason.values())
    receipt = {
        "request_id": request_id,
        "idempotency_key": idempotency_key,
        "idempotency_mode": "correlation_only_no_response_replay",
        "protected_spans_checked": len(protected_values),
        "protected_spans_missing": len(missing),
        "code_blocks_detected": len(code_spans),
        "code_blocks_preserved": code_preserved,
        "json_blocks_detected": len(json_spans),
        "json_blocks_preserved": json_preserved,
        "system_prompt_spans_detected": len(system_spans),
        "system_prompts_preserved": system_preserved,
        "tool_schemas_checked": len(tool_schema_groups),
        "tool_schema_occurrences": sum(
            _schema_occurrences(text, group)
            for group in tool_schema_groups
        ),
        "tool_schema_spans_detected": len(tool_schema_spans),
        "tool_schemas_missing": len(missing_tool_schemas),
        "tool_schemas_missing_preview": [
            next((candidate[:120] for candidate in group if candidate), "")
            for group in missing_tool_schemas
        ],
        "tool_schemas_preserved": tool_schemas_preserved,
        "decision": decision,
        "compressor_latency_ms": round((time.perf_counter() - started) * 1000.0, 3),
        "deletion_only": _is_subsequence(output, text),
        "deterministic": True,
        "rules_version": RULES_VERSION,
        "classifier": {
            "model": CLASSIFIER_MODEL,
            "status": classifier_status,
            "artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR),
            "artifact_dir": CLASSIFIER_ARTIFACT_DIR,
            "error": classifier_error,
            "labels_received": len(classifier_labels),
            "drop_labels": classifier_drop_labels,
            "drop_spans_applied": classifier_applied,
            "drop_spans_blocked_by_safety": classifier_blocked,
            "drop_spans_blocked_by_reason": classifier_blocked_by_reason,
        },
        "removed_spans_count": len(drop_ranges),
        "removed_chars": sum(end - start for start, end in drop_ranges),
        "rule_drop_spans_applied": len(drops),
        "rule_drop_chars": sum(end - start for start, end, _, _ in drops),
        "classifier_drop_spans_applied": len(classifier_ranges),
        "classifier_drop_chars": sum(end - start for start, end in classifier_ranges),
        "dropped_segments_count": len(drops),
        "dropped_segments": dropped_segments,
        "token_count_exact": token_count_exact,
        "token_count_method": token_method,
        "tokenizer_model": tokenizer_model,
    }
    receipt["input_sha256"] = _sha256_text(text)
    receipt["output_sha256"] = _sha256_text(output)
    receipt["removed_sha256"] = _sha256_text(
        "".join(text[start:end] for start, end in drop_ranges)
    )
    receipt["receipt_id"] = _receipt_id({
        "input_sha256": receipt["input_sha256"],
        "output_sha256": receipt["output_sha256"],
        "removed_sha256": receipt["removed_sha256"],
        "tokens_saved": saved,
        "compression_percentage": compression_percentage,
        "decision": decision,
        "rules_version": RULES_VERSION,
        "classifier": receipt["classifier"],
        "dropped_segments": dropped_segments,
    })
    return {
        "schema_version": API_SCHEMA_VERSION,
        "status": "ok",
        "endpoint": "/v1/compress",
        "maturity": "measurement_only",
        "output": output,
        "original_input_tokens": before,
        "output_tokens": after,
        "tokens_saved": saved,
        "compression_percentage": compression_percentage,
        "request_id": request_id,
        "idempotency_key": idempotency_key,
        "receipt": receipt,
    }
def _merge_batch_item_payload(
payload: dict[str, Any],
item: Any,
index: int,
) -> tuple[str | None, dict[str, Any]]:
if isinstance(item, str):
return None, {
"input": item,
"compression_settings": payload.get("compression_settings"),
"protected_spans": payload.get("protected_spans"),
"tool_schemas": payload.get("tool_schemas", payload.get("tools")),
"tokenizer_model": payload.get("tokenizer_model"),
"request_id": payload.get("request_id"),
"idempotency_key": payload.get("idempotency_key"),
}
if not isinstance(item, dict):
raise ValueError(f"inputs[{index}] must be a string or an object")
item_id = item.get("id")
if item_id is not None and not isinstance(item_id, str):
raise ValueError(f"inputs[{index}].id must be a string")
if "input" not in item:
raise ValueError(f"inputs[{index}].input is required")
top_settings = payload.get("compression_settings")
item_settings = item.get("compression_settings")
if top_settings is not None and not isinstance(top_settings, dict):
raise ValueError("compression_settings must be an object")
if item_settings is not None and not isinstance(item_settings, dict):
raise ValueError(f"inputs[{index}].compression_settings must be an object")
settings = {
**(top_settings or {}),
**(item_settings or {}),
} or None
return item_id, {
"input": item.get("input"),
"compression_settings": settings,
"protected_spans": item.get("protected_spans", payload.get("protected_spans")),
"tool_schemas": item.get(
"tool_schemas",
item.get("tools", payload.get("tool_schemas", payload.get("tools"))),
),
"tokenizer_model": item.get("tokenizer_model", payload.get("tokenizer_model")),
"request_id": item.get("request_id", payload.get("request_id")),
"idempotency_key": item.get(
"idempotency_key",
payload.get("idempotency_key"),
),
}
def _handle_batch(payload: dict[str, Any]) -> dict[str, Any]:
    """Compress every entry of ``payload["inputs"]`` independently.

    Request-level problems are rejected with a 400 ``HTTPException``:
    ``input`` mixed with ``inputs``, a non-list or empty ``inputs``, or a
    non-string ``request_id`` or ``idempotency_key``. Per-entry failures are
    handled differently. A ``ValueError`` or ``HTTPException`` raised while
    merging or compressing one entry is recorded as an error result, and the
    rest of the batch continues.

    Token totals and the batch compression percentage cover successful
    entries only. The batch ``status`` is ``"partial_error"`` when any
    entry failed.
    """
    if "input" in payload:
        raise HTTPException(status_code=400, detail="provide either input or inputs, not both")
    inputs = payload.get("inputs")
    if not isinstance(inputs, list):
        raise HTTPException(status_code=400, detail="inputs must be a list")
    if not inputs:
        raise HTTPException(status_code=400, detail="inputs list is empty")
    request_id = payload.get("request_id")
    if request_id is not None and not isinstance(request_id, str):
        raise HTTPException(status_code=400, detail="request_id must be a string")
    idempotency_key = payload.get("idempotency_key")
    if idempotency_key is not None and not isinstance(idempotency_key, str):
        raise HTTPException(status_code=400, detail="idempotency_key must be a string")

    token_fields = ("original_input_tokens", "output_tokens", "tokens_saved")
    totals = {field: 0 for field in token_fields}
    results: list[dict[str, Any]] = []
    succeeded = 0
    failed = 0

    for index, item in enumerate(inputs):
        entry: dict[str, Any] = {"index": index}
        # Record a valid caller id before validation, so it is echoed back
        # even when the entry fails further down.
        if isinstance(item, dict) and isinstance(item.get("id"), str):
            entry["id"] = item["id"]
        try:
            _, item_payload = _merge_batch_item_payload(payload, item, index)
            outcome = _compress_text(item_payload)
        except (HTTPException, ValueError) as exc:
            reason = exc.detail if isinstance(exc, HTTPException) else exc
            entry.update({"status": "error", "error": str(reason)})
            failed += 1
        else:
            entry.update({"status": "ok", **outcome})
            for field in token_fields:
                totals[field] += int(outcome[field])
            succeeded += 1
        results.append(entry)

    original_total = totals["original_input_tokens"]
    if original_total:
        compression_pct = round(100.0 * totals["tokens_saved"] / original_total, 1)
    else:
        compression_pct = 0.0

    receipt_ids = [
        entry["receipt"]["receipt_id"]
        for entry in results
        if entry.get("status") == "ok" and entry.get("receipt", {}).get("receipt_id")
    ]

    body = {
        "schema_version": API_SCHEMA_VERSION,
        "status": "partial_error" if failed else "ok",
        "endpoint": "/v1/compress",
        "maturity": "measurement_only",
        "input_count": len(inputs),
        "succeeded": succeeded,
        "failed": failed,
        **totals,
        "compression_percentage": compression_pct,
        "receipt_ids": receipt_ids,
        "results": results,
    }
    if request_id is not None:
        body["request_id"] = request_id
    if idempotency_key is not None:
        body["idempotency_key"] = idempotency_key
    return body
def _handle_messages(payload: dict[str, Any]) -> dict[str, Any]:
    """Compress the string ``content`` of each message in a ``messages`` list.

    Every message is shallow-copied into the output. String contents are run
    through ``_compress_text``. For protected roles, the whole content is
    added to the protected spans. Non-string contents are passed through
    unchanged and recorded as ``skipped``. Per-message token counts are
    summed, and an aggregate receipt summarises the run.

    Args:
        payload: Decoded request body. Must contain ``messages`` (a non-empty
            list of objects). May contain ``compression_settings``
            (``protected_roles`` / ``compress_roles`` are read here),
            ``protected_spans``, ``tool_schemas`` / ``tools``,
            ``tokenizer_model``, ``request_id`` and ``idempotency_key``.

    Returns:
        Response body with the compressed messages, token accounting,
        per-message receipts, and an aggregate ``receipt``.

    Raises:
        HTTPException: 400 when ``messages`` is mixed with ``input`` /
            ``inputs``, or when one of the fields validated here has the
            wrong type. Errors raised by the helpers called here
            (e.g. ``_compress_text``) propagate unchanged.
    """
    if "input" in payload or "inputs" in payload:
        raise HTTPException(status_code=400, detail="provide either messages, input, or inputs")
    messages = payload.get("messages")
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")
    if not all(isinstance(message, dict) for message in messages):
        raise HTTPException(status_code=400, detail="messages entries must be objects")
    settings = payload.get("compression_settings") or {}
    if not isinstance(settings, dict):
        raise HTTPException(status_code=400, detail="compression_settings must be an object")
    protected_roles = _string_set(
        settings.get("protected_roles"),
        default=DEFAULT_PROTECTED_MESSAGE_ROLES,
        field_name="compression_settings.protected_roles",
    )
    # None means "no allow-list": only protected_roles restricts compression.
    compress_roles = (
        _string_set(
            settings.get("compress_roles"),
            field_name="compression_settings.compress_roles",
        )
        if "compress_roles" in settings
        else None
    )
    protected_values = payload.get("protected_spans") or []
    if not isinstance(protected_values, list) or not all(
        isinstance(value, str) for value in protected_values
    ):
        raise HTTPException(status_code=400, detail="protected_spans must be strings")
    request_id = payload.get("request_id")
    if request_id is not None and not isinstance(request_id, str):
        raise HTTPException(status_code=400, detail="request_id must be a string")
    idempotency_key = payload.get("idempotency_key")
    if idempotency_key is not None and not isinstance(idempotency_key, str):
        raise HTTPException(status_code=400, detail="idempotency_key must be a string")
    tokenizer_model = _optional_string(payload, "tokenizer_model")
    # Loop-invariant: resolve the tool schema list once, not once per message.
    tool_schemas = payload.get("tool_schemas", payload.get("tools"))

    output_messages: list[dict[str, Any]] = []
    receipts: list[dict[str, Any]] = []
    nested_receipts: list[dict[str, Any]] = []
    receipt_ids: list[str] = []
    original_tokens = 0
    output_tokens = 0
    compressed_message_count = 0
    skipped_message_count = 0
    for index, message in enumerate(messages):
        role = _message_role(message, index)
        content = message.get("content")
        # Shallow copy so the caller's message dicts are never mutated.
        output_message = dict(message)
        if not isinstance(content, str):
            skipped_message_count += 1
            receipts.append({
                "index": index,
                "role": role,
                "status": "skipped",
                "reason": "non_string_content",
                "content_type": type(content).__name__,
                "original_input_tokens": 0,
                "output_tokens": 0,
                "tokens_saved": 0,
            })
            output_messages.append(output_message)
            continue
        # Leave a message intact when its role is explicitly protected, or
        # when an allow-list is given and the role is not on it.
        role_protected = role in protected_roles or (
            compress_roles is not None and role not in compress_roles
        )
        item_protected = list(protected_values)
        if role_protected and content:
            # The message is still measured and receipted. Protecting its whole
            # content should keep it verbatim (assumes _compress_text never
            # deletes inside a protected span).
            item_protected.append(content)
        result = _compress_text({
            "input": content,
            "compression_settings": settings,
            "protected_spans": item_protected,
            "tool_schemas": tool_schemas,
            "tokenizer_model": tokenizer_model,
            "request_id": request_id,
            "idempotency_key": idempotency_key,
        })
        output_message["content"] = result["output"]
        output_messages.append(output_message)
        original_tokens += int(result["original_input_tokens"])
        output_tokens += int(result["output_tokens"])
        tokens_saved = int(result["tokens_saved"])
        if tokens_saved > 0:
            compressed_message_count += 1
        receipt = result["receipt"]
        nested_receipts.append(receipt)
        receipt_ids.append(receipt["receipt_id"])
        receipts.append({
            "index": index,
            "role": role,
            "status": "ok",
            "protected_by_role": role_protected,
            "original_input_tokens": result["original_input_tokens"],
            "output_tokens": result["output_tokens"],
            "tokens_saved": tokens_saved,
            "compression_percentage": result["compression_percentage"],
            "receipt_id": receipt["receipt_id"],
            "receipt": receipt,
        })

    # Clamp at zero so net token growth is reported as no savings.
    tokens_saved_total = max(0, original_tokens - output_tokens)
    compression_pct = (
        round(100.0 * tokens_saved_total / original_tokens, 1)
        if original_tokens else 0.0
    )
    decision = _message_decision(
        tokens_saved=tokens_saved_total,
        receipts=nested_receipts,
    )
    # Hash each message list once. The digests feed both the aggregate
    # receipt id and the receipt body (previously each was computed twice).
    input_sha256 = _stable_json_sha256(messages)
    output_sha256 = _stable_json_sha256(output_messages)
    aggregate_receipt = {
        "receipt_version": "message-compression-receipt-v0.1.0",
        "receipt_id": _aggregate_receipt_id({
            "input_sha256": input_sha256,
            "output_sha256": output_sha256,
            "receipt_ids": receipt_ids,
            "tokens_saved": tokens_saved_total,
            "compression_percentage": compression_pct,
            "decision": decision,
        }),
        "request_id": request_id,
        "idempotency_key": idempotency_key,
        "message_count": len(messages),
        "compressed_message_count": compressed_message_count,
        "skipped_message_count": skipped_message_count,
        "protected_roles": sorted(protected_roles),
        "compress_roles": sorted(compress_roles) if compress_roles is not None else None,
        "decision": decision,
        # Vacuously True when every message was skipped.
        "deletion_only": all(
            nested.get("deletion_only", True) for nested in nested_receipts
        ),
        "deterministic": True,
        "input_sha256": input_sha256,
        "output_sha256": output_sha256,
        "message_receipt_ids": receipt_ids,
    }
    compressed_prompt: dict[str, Any] = {
        "messages": output_messages,
        "protected_spans": protected_values,
    }
    # Echo tool definitions only under the key(s) the caller actually sent.
    if "tools" in payload:
        compressed_prompt["tools"] = payload["tools"]
    if "tool_schemas" in payload:
        compressed_prompt["tool_schemas"] = payload["tool_schemas"]
    body = {
        "schema_version": API_SCHEMA_VERSION,
        "status": "ok",
        "endpoint": "/v1/compress",
        "maturity": "measurement_only",
        "messages": output_messages,
        "compressed_messages": output_messages,
        "compressed_prompts": [compressed_prompt],
        "message_count": len(messages),
        "compressed_message_count": compressed_message_count,
        "skipped_message_count": skipped_message_count,
        "original_input_tokens": original_tokens,
        "output_tokens": output_tokens,
        "tokens_saved": tokens_saved_total,
        "compression_percentage": compression_pct,
        "receipt_ids": receipt_ids,
        "receipts": receipts,
        "receipt": aggregate_receipt,
    }
    if request_id is not None:
        body["request_id"] = request_id
    if idempotency_key is not None:
        body["idempotency_key"] = idempotency_key
    return body
def _dispatch_compress(payload: dict[str, Any]) -> dict[str, Any]:
    """Route a /v1/compress payload to the batch, messages, or single-text handler.

    ``inputs`` is checked first, then ``messages``. Anything else is treated
    as a single ``input`` request.
    """
    if "inputs" in payload:
        handler = _handle_batch
    elif "messages" in payload:
        handler = _handle_messages
    else:
        handler = _compress_text
    return handler(payload)
def _tokens(text: str) -> list[dict[str, Any]]:
    """Split *text* into KEEP-labelled tokens with character offsets.

    First tries the tokenizer returned by ``_get_tokenizer()``, requesting
    offset mappings and no special tokens. If anything on that path raises,
    it falls back to a regex split into alternating non-whitespace and
    whitespace runs. Each token's ``source`` records which path was used.

    Every token is labelled ``"KEEP"`` with score ``1.0``. Only the first
    token carries the measured tokenization time in
    ``classifier_latency_ms``; all others report ``0.0``.
    """
    t0 = time.perf_counter()
    try:
        tokenizer = _get_tokenizer()
        encoding = tokenizer(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
        pieces = tokenizer.convert_ids_to_tokens(encoding["input_ids"])
        spans = encoding["offset_mapping"]
        source = "deberta_tokenizer"
    except Exception:
        # Best-effort fallback: any tokenizer failure degrades to regex runs.
        runs = list(re.finditer(r"\S+|\s+", text))
        pieces = [run.group(0) for run in runs]
        spans = [run.span() for run in runs]
        source = "regex_fallback"
    latency_ms = round((time.perf_counter() - t0) * 1000.0, 3)

    tokens: list[dict[str, Any]] = []
    for position, piece in enumerate(pieces):
        span = spans[position]
        tokens.append({
            "token": piece,
            "label": "KEEP",
            "score": 1.0,
            "start": int(span[0]),
            "end": int(span[1]),
            "source": source,
            "classifier_latency_ms": latency_ms if position == 0 else 0.0,
        })
    return tokens
@app.get("/health")
def health() -> dict[str, Any]:
    # Liveness/configuration report. Comments are used instead of a docstring
    # because FastAPI would publish a docstring as the OpenAPI description.
    # The classifier is probed with an empty input; the returned labels are
    # discarded, and only its status and error are reported.
    classifier_status, _labels, classifier_error = _manifest_drop_labels("")
    transport_info = {
        "request_id_header": REQUEST_ID_HEADER,
        "idempotency_key_header": IDEMPOTENCY_KEY_HEADER,
        "content_encoding_fallback_header": TOUCHDOWN_CONTENT_ENCODING_HEADER,
        "gzip_request_bodies": True,
        "gzip_responses": True,
        "idempotency_mode": "in_memory_response_replay",
        "idempotency_ttl_seconds": _idempotency_ttl_seconds(),
    }
    return {
        "status": "ok",
        "classifier_model": CLASSIFIER_MODEL,
        "classifier_status": classifier_status,
        "classifier_artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR),
        "classifier_error": classifier_error,
        "phase": "rules_api_first",
        "transport": transport_info,
    }
@app.post("/v1/classify")
def classify(payload: dict[str, Any]) -> dict[str, Any]:
    # Token-level KEEP/DROP classification endpoint. Comments are used instead
    # of a docstring so the generated OpenAPI description stays unchanged.
    text = payload.get("input")
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="input must be a string")
    classifier_status, artifact_labels, classifier_error = _manifest_drop_labels(text)
    # Prefer labels from a mounted artifact; otherwise emit KEEP-only tokens.
    labels = artifact_labels if artifact_labels else _tokens(text)
    # Report the classifier's own status only when there is something to
    # report: a configured artifact dir, produced labels, or a load error.
    if CLASSIFIER_ARTIFACT_DIR or artifact_labels or classifier_error:
        status = classifier_status
    else:
        status = "tokenizer_only_until_trained_head"
    note = (
        "This free Space scaffold loads the DeBERTa tokenizer and returns "
        "KEEP-only labels until a trained KEEP/DROP head or ONNX export is "
        "mounted. Do not treat it as a production compressor."
    )
    return {
        "model": CLASSIFIER_MODEL,
        "task": "token_classification_keep_drop",
        "status": status,
        "artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR),
        "error": classifier_error,
        "labels": labels,
        "note": note,
    }
@app.post("/v1/compress")
async def compress(request: Request) -> dict[str, Any]:
    # Compression endpoint. Comments are used instead of a docstring so the
    # generated OpenAPI description stays unchanged.
    # Undo any declared transport encoding. The standard Content-Encoding
    # header wins over the X-Touchdown-Content-Encoding fallback.
    raw = await request.body()
    declared_encoding = (
        request.headers.get(CONTENT_ENCODING_HEADER)
        or request.headers.get(TOUCHDOWN_CONTENT_ENCODING_HEADER)
    )
    try:
        body, _was_gzip = _decode_request_body(raw, declared_encoding)
    except UnsupportedContentEncoding as exc:
        # Handled before ValueError in case the encoding error subclasses it.
        raise HTTPException(status_code=415, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # An empty body is treated as an empty JSON object.
    payload: Any = {}
    if body:
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as exc:
            raise HTTPException(status_code=400, detail="invalid JSON body") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")

    # Attach the request id (from middleware state, if present) and the
    # Idempotency-Key header before dispatching.
    correlated = _correlation_payload(
        payload,
        request_id=getattr(request.state, "request_id", None),
        idempotency_key=request.headers.get(IDEMPOTENCY_KEY_HEADER),
    )
    return _handle_compress_with_idempotency(correlated)