from __future__ import annotations import copy import gzip import hashlib import json import math import os import re import threading import time import uuid from functools import lru_cache from pathlib import Path from typing import Any from fastapi import FastAPI, HTTPException, Request, Response CLASSIFIER_MODEL = "microsoft/deberta-v3-small" CLASSIFIER_ARTIFACT_DIR = os.environ.get("TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR") API_SCHEMA_VERSION = "0.1.0" RULES_VERSION = "hf-space-rules-v0.1.0" REQUEST_ID_HEADER = "X-Request-ID" IDEMPOTENCY_KEY_HEADER = "Idempotency-Key" CONTENT_ENCODING_HEADER = "Content-Encoding" TOUCHDOWN_CONTENT_ENCODING_HEADER = "X-Touchdown-Content-Encoding" ACCEPT_ENCODING_HEADER = "Accept-Encoding" GZIP_ENCODING = "gzip" GZIP_MAGIC = b"\x1f\x8b" DEFAULT_IDEMPOTENCY_TTL_SECONDS = 24 * 60 * 60 IDEMPOTENCY_TTL_ENV = "TOUCHDOWN_IDEMPOTENCY_TTL_SECONDS" DEFAULT_PROTECTED_MESSAGE_ROLES = ("system", "developer") LOW_SIGNAL_PATTERNS = [ re.compile(pattern, re.IGNORECASE) for pattern in [ r"\blet me know if you (need|want|have)\b", r"\bi hope this helps\b", r"\bas an ai language model\b", r"\bof course[,!]?\s*$", r"\bcertainly[,!]?\s*$", r"\bsure[,!]?\s*$", ] ] app = FastAPI(title="Touchdown Compression Classifier", version="0.1.0") @app.middleware("http") async def transport_middleware(request: Request, call_next): request_id = request.headers.get(REQUEST_ID_HEADER) or f"tdreq_{uuid.uuid4().hex}" try: body, request_was_gzip = _decode_request_body( await request.body(), request.headers.get(CONTENT_ENCODING_HEADER) or request.headers.get(TOUCHDOWN_CONTENT_ENCODING_HEADER), ) except UnsupportedContentEncoding as exc: response = Response( content=json.dumps({ "status": "error", "error": str(exc), "request_id": request_id, }), status_code=415, media_type="application/json", headers={REQUEST_ID_HEADER: request_id}, ) return await _gzip_response_if_needed( request, response, request_was_gzip=False, ) except ValueError as exc: response = Response( content=json.dumps({ "status": "error", "error": str(exc), "request_id": request_id, }), status_code=400, media_type="application/json", headers={REQUEST_ID_HEADER: request_id}, ) return await _gzip_response_if_needed( request, response, request_was_gzip=False, ) body_sent = False async def receive() -> dict[str, Any]: nonlocal body_sent if body_sent: return {"type": "http.request", "body": b"", "more_body": False} body_sent = True return {"type": "http.request", "body": body, "more_body": False} scope = dict(request.scope) scope["headers"] = [ ( key, str(len(body)).encode("latin-1") if key.lower() == b"content-length" else value, ) for key, value in request.scope.get("headers", []) if key.lower() not in (b"content-encoding", b"x-touchdown-content-encoding") ] request = Request(scope, receive) request.state.request_id = request_id response = await call_next(request) response.headers[REQUEST_ID_HEADER] = request_id return await _gzip_response_if_needed( request, response, request_was_gzip=request_was_gzip, ) class UnsupportedContentEncoding(ValueError): """Raised when an HTTP request uses an unsupported content encoding.""" _IDEMPOTENCY_LOCK = threading.Lock() _IDEMPOTENCY_CACHE: dict[tuple[str, str, str], dict[str, Any]] = {} def _content_encoding(value: str | None) -> str: return (value or "identity").split(";", 1)[0].strip().lower() or "identity" def _decode_request_body(raw: bytes, content_encoding: str | None) -> tuple[bytes, bool]: encoding = _content_encoding(content_encoding) if encoding == "identity": if raw.startswith(GZIP_MAGIC): try: return gzip.decompress(raw), True except OSError as exc: raise ValueError("invalid gzip request body") from exc return raw, False if encoding != GZIP_ENCODING: raise UnsupportedContentEncoding( f"unsupported content encoding: {content_encoding}" ) try: return gzip.decompress(raw), True except OSError as exc: raise ValueError("invalid gzip request body") from exc def _accepts_gzip(accept_encoding: str | None) -> bool: for part in (accept_encoding or "").split(","): normalized = part.replace(" ", "").lower() token = normalized.split(";", 1)[0] if token in (GZIP_ENCODING, "*") and "q=0" not in normalized: return True return False def _should_gzip_response( *, accept_encoding: str | None, request_was_gzip: bool, ) -> bool: return request_was_gzip or _accepts_gzip(accept_encoding) def _idempotency_ttl_seconds() -> int: raw = os.environ.get(IDEMPOTENCY_TTL_ENV) if raw is None: return DEFAULT_IDEMPOTENCY_TTL_SECONDS try: parsed = int(raw) except ValueError: return DEFAULT_IDEMPOTENCY_TTL_SECONDS return max(0, parsed) def _idempotency_key_from_payload(payload: dict[str, Any]) -> str | None: value = payload.get("idempotency_key") return value if isinstance(value, str) and value else None def _fingerprint_value(value: Any) -> Any: if isinstance(value, dict): return { key: _fingerprint_value(item) for key, item in sorted(value.items()) if key not in {"request_id", "idempotency_key"} } if isinstance(value, list): return [_fingerprint_value(item) for item in value] return value def _idempotency_fingerprint(route: str, payload: dict[str, Any]) -> str: encoded = json.dumps( { "route": route, "payload": _fingerprint_value(payload), }, ensure_ascii=False, sort_keys=True, separators=(",", ":"), ) return hashlib.sha256(encoded.encode("utf-8")).hexdigest() def _walk_receipts(value: Any) -> list[dict[str, Any]]: receipts: list[dict[str, Any]] = [] if isinstance(value, dict): receipt = value.get("receipt") if isinstance(receipt, dict): receipts.append(receipt) for item in value.values(): receipts.extend(_walk_receipts(item)) elif isinstance(value, list): for item in value: receipts.extend(_walk_receipts(item)) return receipts def _mark_idempotency( body: dict[str, Any], *, key: str, fingerprint: str, replayed: bool, original_request_id: str | None, ) -> dict[str, Any]: ttl_seconds = _idempotency_ttl_seconds() marked = copy.deepcopy(body) marked["idempotency_key"] = key marked["idempotency_mode"] = "in_memory_response_replay" marked["idempotency_replayed"] = replayed marked["idempotency_fingerprint"] = fingerprint marked["idempotency_ttl_seconds"] = ttl_seconds if original_request_id: marked["idempotency_original_request_id"] = original_request_id for receipt in _walk_receipts(marked): if receipt.get("idempotency_key") in (None, key): receipt["idempotency_key"] = key receipt["idempotency_mode"] = "in_memory_response_replay" receipt["idempotency_replayed"] = replayed receipt["idempotency_fingerprint"] = fingerprint receipt["idempotency_ttl_seconds"] = ttl_seconds if original_request_id: receipt["idempotency_original_request_id"] = original_request_id return marked def _cached_idempotency_body( *, route: str, key: str, fingerprint: str, ) -> dict[str, Any] | None: now = time.monotonic() with _IDEMPOTENCY_LOCK: expired = [ cache_key for cache_key, record in _IDEMPOTENCY_CACHE.items() if float(record["expires_at"]) <= now ] for cache_key in expired: _IDEMPOTENCY_CACHE.pop(cache_key, None) prefix = (route, key) for cache_key, record in _IDEMPOTENCY_CACHE.items(): if cache_key[:2] != prefix: continue if record["fingerprint"] != fingerprint: raise HTTPException(status_code=409, detail={ "schema_version": API_SCHEMA_VERSION, "status": "error", "error": "idempotency_key_reused_with_different_payload", "idempotency_key": key, "idempotency_fingerprint": fingerprint, "idempotency_existing_fingerprint": record["fingerprint"], }) return _mark_idempotency( record["body"], key=key, fingerprint=fingerprint, replayed=True, original_request_id=record.get("request_id"), ) return None def _store_idempotency_body( *, route: str, key: str, fingerprint: str, body: dict[str, Any], request_id: str | None, ) -> dict[str, Any]: marked = _mark_idempotency( body, key=key, fingerprint=fingerprint, replayed=False, original_request_id=request_id, ) ttl_seconds = _idempotency_ttl_seconds() if ttl_seconds > 0: with _IDEMPOTENCY_LOCK: _IDEMPOTENCY_CACHE[(route, key, fingerprint)] = { "fingerprint": fingerprint, "body": marked, "request_id": request_id, "expires_at": time.monotonic() + ttl_seconds, } return marked def _handle_compress_with_idempotency(payload: dict[str, Any]) -> dict[str, Any]: key = _idempotency_key_from_payload(payload) if not key: return _dispatch_compress(payload) route = "/v1/compress" fingerprint = _idempotency_fingerprint(route, payload) cached = _cached_idempotency_body( route=route, key=key, fingerprint=fingerprint, ) if cached is not None: return cached body = _dispatch_compress(payload) request_id = payload.get("request_id") if isinstance(payload.get("request_id"), str) else None return _store_idempotency_body( route=route, key=key, fingerprint=fingerprint, body=body, request_id=request_id, ) async def _gzip_response_if_needed( request: Request, response: Response, *, request_was_gzip: bool, ) -> Response: should_gzip = _should_gzip_response( accept_encoding=request.headers.get(ACCEPT_ENCODING_HEADER), request_was_gzip=request_was_gzip, ) if not should_gzip or response.headers.get(CONTENT_ENCODING_HEADER): return response if hasattr(response, "body_iterator"): body = b"" async for chunk in response.body_iterator: body += chunk else: body = getattr(response, "body", b"") compressed = gzip.compress(body) headers = { key: value for key, value in response.headers.items() if key.lower() not in ("content-length", "content-encoding") } headers[CONTENT_ENCODING_HEADER] = GZIP_ENCODING headers["Vary"] = ACCEPT_ENCODING_HEADER headers["Content-Length"] = str(len(compressed)) return Response( content=compressed, status_code=response.status_code, headers=headers, background=getattr(response, "background", None), ) @lru_cache(maxsize=1) def _get_tokenizer(): from transformers import AutoTokenizer return AutoTokenizer.from_pretrained(CLASSIFIER_MODEL) @lru_cache(maxsize=8) def _get_count_tokenizer(model_name: str): from transformers import AutoTokenizer return AutoTokenizer.from_pretrained(model_name) @lru_cache(maxsize=1) def _classifier_manifest() -> dict[str, Any] | None: if not CLASSIFIER_ARTIFACT_DIR: return None path = Path(CLASSIFIER_ARTIFACT_DIR) / "classifier_manifest.json" if not path.exists(): return { "status": "manifest_error", "error": f"missing classifier manifest: {path}", } try: payload = json.loads(path.read_text(encoding="utf-8")) except Exception as exc: return {"status": "manifest_error", "error": str(exc)} if not isinstance(payload, dict): return {"status": "manifest_error", "error": "manifest must be an object"} return payload @lru_cache(maxsize=1) def _get_artifact_tokenizer(): from transformers import AutoTokenizer return AutoTokenizer.from_pretrained(str(CLASSIFIER_ARTIFACT_DIR)) @lru_cache(maxsize=1) def _get_onnx_session(model_path: str): import onnxruntime as ort return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) def _find_spans(text: str, needle: str) -> list[tuple[int, int]]: spans = [] cursor = 0 while needle: index = text.find(needle, cursor) if index < 0: return spans spans.append((index, index + len(needle))) cursor = index + len(needle) return spans def _regex_spans(text: str, pattern: str, flags: int) -> list[tuple[int, int]]: return [(match.start(), match.end()) for match in re.finditer(pattern, text, flags)] def _json_line_spans(text: str) -> list[tuple[int, int]]: spans = [] cursor = 0 for line in text.splitlines(keepends=True): stripped = line.strip() if stripped and stripped[0] in "[{": try: import json json.loads(stripped) except Exception: pass else: offset = line.index(stripped[0]) spans.append((cursor + offset, cursor + offset + len(stripped))) cursor += len(line) return spans def _merge(spans: list[tuple[int, int]]) -> list[tuple[int, int]]: if not spans: return [] ordered = sorted(spans) merged = [ordered[0]] for start, end in ordered[1:]: prev_start, prev_end = merged[-1] if start <= prev_end: merged[-1] = (prev_start, max(prev_end, end)) else: merged.append((start, end)) return merged def _overlaps(spans: list[tuple[int, int]], start: int, end: int) -> bool: return any(start < span_end and end > span_start for span_start, span_end in spans) def _manifest_drop_labels(text: str) -> tuple[str, list[dict[str, Any]], str | None]: manifest = _classifier_manifest() if not manifest: return "tokenizer_loaded_no_trained_keep_drop_head", [], None if manifest.get("status") == "manifest_error": return "manifest_error", [], str(manifest.get("error")) onnx_model = Path(CLASSIFIER_ARTIFACT_DIR or "") / str( manifest.get("onnx_model", "model.onnx") ) if onnx_model.exists(): try: return ( str(manifest.get("status") or "trained_keep_drop_head"), _onnx_labels(text, manifest, onnx_model), None, ) except Exception as exc: return "local_artifact_error", [], str(exc) labels: list[dict[str, Any]] = [] for index, entry in enumerate(manifest.get("drop_phrases") or []): if isinstance(entry, str): phrase = entry score = 1.0 elif isinstance(entry, dict): phrase = str(entry.get("text") or "") score = float(entry.get("score", 1.0)) else: continue cursor = 0 while phrase: start = text.find(phrase, cursor) if start < 0: break end = start + len(phrase) labels.append({ "token": phrase, "label": "DROP", "score": score, "start": start, "end": end, "source": "manifest_drop_phrase", "classifier_index": index, }) cursor = end return str(manifest.get("status") or "artifact_manifest_loaded"), labels, None def _softmax(values: list[float]) -> list[float]: peak = max(values) exps = [math.exp(value - peak) for value in values] total = sum(exps) or 1.0 return [value / total for value in exps] def _onnx_labels( text: str, manifest: dict[str, Any], model_path: Path, ) -> list[dict[str, Any]]: import numpy as np tokenizer = _get_artifact_tokenizer() encoded = tokenizer( text, add_special_tokens=False, return_offsets_mapping=True, return_tensors="np", truncation=True, max_length=int(manifest.get("max_length", 512)), stride=int(manifest.get("stride", 0)), return_overflowing_tokens=True, padding=True, ) offsets = encoded.pop("offset_mapping") input_ids = encoded["input_ids"] session = _get_onnx_session(str(model_path)) input_names = {item.name for item in session.get_inputs()} inputs = {key: value for key, value in encoded.items() if key in input_names} logits = session.run(None, inputs)[0] id2label = { str(key): value for key, value in (manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}).items() } label2id = {str(value).upper(): int(key) for key, value in id2label.items()} keep_id = label2id.get("KEEP", 0) drop_id = label2id.get("DROP", 1) drop_score_threshold = float(manifest.get("drop_score_threshold", 0.5)) best_by_span: dict[tuple[int, int], dict[str, Any]] = {} for chunk_index, chunk_logits in enumerate(logits): tokens = tokenizer.convert_ids_to_tokens(input_ids[chunk_index]) for token_index, token_logits in enumerate(chunk_logits): probs = _softmax(np.asarray(token_logits, dtype=float).tolist()) keep_score = float(probs[keep_id]) if keep_id < len(probs) else 0.0 drop_score = float(probs[drop_id]) if drop_id < len(probs) else 0.0 label_id = drop_id if drop_score >= drop_score_threshold else keep_id start, end = offsets[chunk_index][token_index] start = int(start) end = int(end) if end <= start: continue key = (start, end) item = { "token": tokens[token_index], "label": id2label.get(str(label_id), "KEEP"), "score": round(drop_score if label_id == drop_id else keep_score, 6), "keep_score": round(keep_score, 6), "drop_score": round(drop_score, 6), "drop_score_threshold": drop_score_threshold, "start": start, "end": end, "source": "onnx_token_classifier", "chunk_index": chunk_index, } if drop_score > float(best_by_span.get(key, {}).get("drop_score", -1.0)): best_by_span[key] = item return [best_by_span[key] for key in sorted(best_by_span)] def _safe_classifier_drop_ranges( labels: list[dict[str, Any]], protected: list[tuple[int, int]], *, text_length: int, min_score: float, ) -> tuple[list[tuple[int, int]], int, int, dict[str, int]]: blocked_by_reason = { "malformed_or_missing_offsets": 0, "out_of_bounds": 0, "below_min_score": 0, "protected_span_overlap": 0, } ranges: list[tuple[int, int]] = [] drop_labels = 0 for item in labels: raw_drop_score = item.get("drop_score") try: drop_score = float(raw_drop_score) if raw_drop_score is not None else None except (TypeError, ValueError): drop_score = None if ( str(item.get("label") or item.get("entity") or "").upper() != "DROP" and (drop_score is None or drop_score < min_score) ): continue drop_labels += 1 try: start = int(item["start"]) end = int(item["end"]) score = drop_score if drop_score is not None else float(item.get("score", 1.0)) except Exception: blocked_by_reason["malformed_or_missing_offsets"] += 1 continue if start < 0 or end > text_length or end <= start: blocked_by_reason["out_of_bounds"] += 1 continue if score < min_score: blocked_by_reason["below_min_score"] += 1 continue if _overlaps(protected, start, end): blocked_by_reason["protected_span_overlap"] += 1 continue ranges.append((start, end)) merged = _merge(ranges) return merged, drop_labels, len(merged), blocked_by_reason def _is_subsequence(candidate: str, original: str) -> bool: cursor = 0 for char in candidate: cursor = original.find(char, cursor) if cursor < 0: return False cursor += 1 return True def _sha256_text(value: str) -> str: return hashlib.sha256(value.encode("utf-8")).hexdigest() def _receipt_id(payload: dict[str, Any]) -> str: encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")) return "tdcr_" + hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:24] def _stable_json_sha256(value: Any) -> str: encoded = json.dumps( value, ensure_ascii=False, sort_keys=True, separators=(",", ":"), ) return hashlib.sha256(encoded.encode("utf-8")).hexdigest() def _aggregate_receipt_id(payload: dict[str, Any]) -> str: encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")) return "tdcm_" + hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:24] def _string_set( value: Any, *, default: tuple[str, ...] = (), field_name: str, ) -> set[str]: if value is None: return set(default) if ( isinstance(value, list) and all(isinstance(item, str) and item for item in value) ): return {item.lower() for item in value} raise HTTPException(status_code=400, detail=f"{field_name} must be a list of strings") def _message_role(message: dict[str, Any], index: int) -> str: role = message.get("role") if not isinstance(role, str) or not role: raise HTTPException( status_code=400, detail=f"messages[{index}].role must be a string", ) return role.lower() def _message_decision( *, tokens_saved: int, receipts: list[dict[str, Any]], ) -> str: decisions = [ receipt.get("decision") for receipt in receipts if isinstance(receipt, dict) ] if any(decision == "reject" for decision in decisions): return "reject" if any(decision == "needs_review" for decision in decisions): return "needs_review" if tokens_saved <= 0: return "no_op" return "high_confidence" def _correlation_payload( payload: dict[str, Any], *, request_id: str | None = None, idempotency_key: str | None = None, ) -> dict[str, Any]: out = dict(payload) if request_id and "request_id" not in out: out["request_id"] = request_id if idempotency_key and "idempotency_key" not in out: out["idempotency_key"] = idempotency_key return out def _tool_schema_variants(schema: Any) -> list[str]: if isinstance(schema, str): return [schema] if schema else [] if not isinstance(schema, (dict, list)): raise HTTPException( status_code=400, detail="tool_schemas entries must be strings, objects, or arrays", ) variants: list[str] = [] for kwargs in ( {"ensure_ascii": False, "sort_keys": False, "separators": (",", ":")}, {"ensure_ascii": False, "sort_keys": True, "separators": (",", ":")}, {"ensure_ascii": False, "sort_keys": False, "indent": 2}, {"ensure_ascii": False, "sort_keys": True, "indent": 2}, ): encoded = json.dumps(schema, **kwargs) if encoded not in variants: variants.append(encoded) return variants def _tool_schema_groups(tool_schemas: Any) -> list[list[str]]: if tool_schemas is None: return [] if not isinstance(tool_schemas, list): raise HTTPException(status_code=400, detail="tool_schemas must be a list") return [_tool_schema_variants(schema) for schema in tool_schemas] def _schema_occurrences(text: str, group: list[str]) -> int: return sum(len(_find_spans(text, candidate)) for candidate in group) def _tool_schema_missing_groups( input_text: str, output: str, groups: list[list[str]], ) -> list[list[str]]: missing = [] for group in groups: input_occurrences = _schema_occurrences(input_text, group) output_occurrences = _schema_occurrences(output, group) if input_occurrences > 0 and output_occurrences == 0: missing.append(group) return missing def _optional_string(payload: dict[str, Any], key: str) -> str | None: value = payload.get(key) if value is None: return None if not isinstance(value, str): raise HTTPException(status_code=400, detail=f"{key} must be a string") return value def _count_tokens(text: str, tokenizer_model: str | None) -> tuple[int, bool, str]: if tokenizer_model: try: tokenizer = _get_count_tokenizer(tokenizer_model) return ( len(tokenizer.encode(text, add_special_tokens=False)), True, "tokenizer", ) except Exception: pass return max(1, round(len(text) / 4.0)), False, "chars_per_token_estimate" def _protected_spans( text: str, protected_values: list[str], tool_schema_groups: list[list[str]], *, preserve_code: bool, preserve_json: bool, ) -> tuple[ list[tuple[int, int]], list[tuple[int, int]], list[tuple[int, int]], list[tuple[int, int]], list[tuple[int, int]], ]: spans = [] for value in protected_values: spans.extend(_find_spans(text, value)) tool_schema_spans = [] for group in tool_schema_groups: for candidate in group: tool_schema_spans.extend(_find_spans(text, candidate)) spans.extend(tool_schema_spans) code_spans = ( _regex_spans(text, r"```.*?```", re.MULTILINE | re.DOTALL) if preserve_code else [] ) json_spans = ( _regex_spans(text, r"```(?:json|JSON)\s*\n.*?```", re.MULTILINE | re.DOTALL) + _json_line_spans(text) if preserve_json else [] ) system_spans = ( _regex_spans(text, r"]*>.*?", re.MULTILINE | re.DOTALL) + _regex_spans(text, r"^system\s*:\s*.*$", re.MULTILINE | re.IGNORECASE) ) spans.extend(code_spans) spans.extend(json_spans) spans.extend(system_spans) return ( _merge(spans), code_spans, json_spans, system_spans, _merge(tool_schema_spans), ) def _compress_text(payload: dict[str, Any]) -> dict[str, Any]: text = payload.get("input") if not isinstance(text, str): raise HTTPException(status_code=400, detail="input must be a string") settings = payload.get("compression_settings") or {} if not isinstance(settings, dict): raise HTTPException(status_code=400, detail="compression_settings must be an object") request_id = payload.get("request_id") if request_id is not None and not isinstance(request_id, str): raise HTTPException(status_code=400, detail="request_id must be a string") idempotency_key = payload.get("idempotency_key") if idempotency_key is not None and not isinstance(idempotency_key, str): raise HTTPException(status_code=400, detail="idempotency_key must be a string") tokenizer_model = _optional_string(payload, "tokenizer_model") protected_values = payload.get("protected_spans") or [] if not isinstance(protected_values, list) or not all( isinstance(value, str) for value in protected_values ): raise HTTPException(status_code=400, detail="protected_spans must be strings") tool_schema_groups = _tool_schema_groups( payload.get("tool_schemas", payload.get("tools")) ) try: aggressiveness = float(settings.get("aggressiveness", 0.35)) except Exception as exc: raise HTTPException(status_code=400, detail="aggressiveness must be numeric") from exc if not 0.0 <= aggressiveness <= 1.0: raise HTTPException(status_code=400, detail="aggressiveness must be 0..1") started = time.perf_counter() preserve_code = bool(settings.get("preserve_code", True)) preserve_json = bool(settings.get("preserve_json", True)) protected, code_spans, json_spans, system_spans, tool_schema_spans = _protected_spans( text, protected_values, tool_schema_groups, preserve_code=preserve_code, preserve_json=preserve_json, ) classifier_status, classifier_labels, classifier_error = _manifest_drop_labels(text) classifier_ranges, classifier_drop_labels, classifier_applied, classifier_blocked_by_reason = ( _safe_classifier_drop_ranges( classifier_labels, protected, text_length=len(text), min_score=float(settings.get("classifier_min_drop_score", 0.5)), ) ) seen: set[str] = set() drops: list[tuple[int, int, str, str]] = [] cursor = 0 previous_blank_kept = False for line in text.splitlines(keepends=True): start = cursor end = cursor + len(line) cursor = end stripped = line.strip() normalized = re.sub(r"\s+", " ", stripped).lower() if _overlaps(protected, start, end): if normalized: seen.add(normalized) previous_blank_kept = False continue reason = None if not stripped and previous_blank_kept: reason = "extra_blank" elif len(normalized) > 30 and normalized in seen and aggressiveness >= 0.10: reason = "duplicate_line" elif aggressiveness >= 0.20 and any( pattern.search(stripped) for pattern in LOW_SIGNAL_PATTERNS ): reason = "low_signal_phrase" if reason: drops.append((start, end, reason, stripped[:120])) continue previous_blank_kept = not stripped if normalized: seen.add(normalized) chunks = [] drop_ranges = _merge([(start, end) for start, end, _, _ in drops] + classifier_ranges) cursor = 0 for start, end in drop_ranges: chunks.append(text[cursor:start]) cursor = end chunks.append(text[cursor:]) output = "".join(chunks) before, before_exact, token_method = _count_tokens(text, tokenizer_model) after, after_exact, after_method = _count_tokens(output, tokenizer_model) token_count_exact = before_exact and after_exact if after_method != token_method: token_method = "chars_per_token_estimate" saved = max(0, before - after) missing = [value for value in protected_values if value and value not in output] code_preserved = all(text[start:end] in output for start, end in code_spans) json_preserved = all(text[start:end] in output for start, end in json_spans) system_preserved = all(text[start:end] in output for start, end in system_spans) missing_tool_schemas = _tool_schema_missing_groups( text, output, tool_schema_groups, ) tool_schemas_preserved = not missing_tool_schemas decision = "reject" if ( missing or not code_preserved or not json_preserved or not system_preserved or not tool_schemas_preserved ) else ( "high_confidence" if saved > 0 and aggressiveness <= 0.65 else "no_op" ) dropped_segments = [ {"reason": reason, "preview": preview, "start": start, "end": end} for start, end, reason, preview in drops[:20] ] classifier_blocked = sum(classifier_blocked_by_reason.values()) receipt = { "request_id": request_id, "idempotency_key": idempotency_key, "idempotency_mode": "correlation_only_no_response_replay", "protected_spans_checked": len(protected_values), "protected_spans_missing": len(missing), "code_blocks_detected": len(code_spans), "code_blocks_preserved": code_preserved, "json_blocks_detected": len(json_spans), "json_blocks_preserved": json_preserved, "system_prompt_spans_detected": len(system_spans), "system_prompts_preserved": system_preserved, "tool_schemas_checked": len(tool_schema_groups), "tool_schema_occurrences": sum( _schema_occurrences(text, group) for group in tool_schema_groups ), "tool_schema_spans_detected": len(tool_schema_spans), "tool_schemas_missing": len(missing_tool_schemas), "tool_schemas_missing_preview": [ next((candidate[:120] for candidate in group if candidate), "") for group in missing_tool_schemas ], "tool_schemas_preserved": tool_schemas_preserved, "decision": decision, "compressor_latency_ms": round((time.perf_counter() - started) * 1000.0, 3), "deletion_only": _is_subsequence(output, text), "deterministic": True, "rules_version": RULES_VERSION, "classifier": { "model": CLASSIFIER_MODEL, "status": classifier_status, "artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR), "artifact_dir": CLASSIFIER_ARTIFACT_DIR, "error": classifier_error, "labels_received": len(classifier_labels), "drop_labels": classifier_drop_labels, "drop_spans_applied": classifier_applied, "drop_spans_blocked_by_safety": classifier_blocked, "drop_spans_blocked_by_reason": classifier_blocked_by_reason, }, "removed_spans_count": len(drop_ranges), "removed_chars": sum(end - start for start, end in drop_ranges), "rule_drop_spans_applied": len(drops), "rule_drop_chars": sum(end - start for start, end, _, _ in drops), "classifier_drop_spans_applied": len(classifier_ranges), "classifier_drop_chars": sum(end - start for start, end in classifier_ranges), "dropped_segments_count": len(drops), "dropped_segments": dropped_segments, "token_count_exact": token_count_exact, "token_count_method": token_method, "tokenizer_model": tokenizer_model, } receipt["input_sha256"] = _sha256_text(text) receipt["output_sha256"] = _sha256_text(output) receipt["removed_sha256"] = _sha256_text( "".join(text[start:end] for start, end in drop_ranges) ) receipt["receipt_id"] = _receipt_id({ "input_sha256": receipt["input_sha256"], "output_sha256": receipt["output_sha256"], "removed_sha256": receipt["removed_sha256"], "tokens_saved": saved, "compression_percentage": round(100.0 * saved / before, 1), "decision": decision, "rules_version": RULES_VERSION, "classifier": receipt["classifier"], "dropped_segments": dropped_segments, }) return { "schema_version": API_SCHEMA_VERSION, "status": "ok", "endpoint": "/v1/compress", "maturity": "measurement_only", "output": output, "original_input_tokens": before, "output_tokens": after, "tokens_saved": saved, "compression_percentage": round(100.0 * saved / before, 1), "request_id": request_id, "idempotency_key": idempotency_key, "receipt": receipt, } def _merge_batch_item_payload( payload: dict[str, Any], item: Any, index: int, ) -> tuple[str | None, dict[str, Any]]: if isinstance(item, str): return None, { "input": item, "compression_settings": payload.get("compression_settings"), "protected_spans": payload.get("protected_spans"), "tool_schemas": payload.get("tool_schemas", payload.get("tools")), "tokenizer_model": payload.get("tokenizer_model"), "request_id": payload.get("request_id"), "idempotency_key": payload.get("idempotency_key"), } if not isinstance(item, dict): raise ValueError(f"inputs[{index}] must be a string or an object") item_id = item.get("id") if item_id is not None and not isinstance(item_id, str): raise ValueError(f"inputs[{index}].id must be a string") if "input" not in item: raise ValueError(f"inputs[{index}].input is required") top_settings = payload.get("compression_settings") item_settings = item.get("compression_settings") if top_settings is not None and not isinstance(top_settings, dict): raise ValueError("compression_settings must be an object") if item_settings is not None and not isinstance(item_settings, dict): raise ValueError(f"inputs[{index}].compression_settings must be an object") settings = { **(top_settings or {}), **(item_settings or {}), } or None return item_id, { "input": item.get("input"), "compression_settings": settings, "protected_spans": item.get("protected_spans", payload.get("protected_spans")), "tool_schemas": item.get( "tool_schemas", item.get("tools", payload.get("tool_schemas", payload.get("tools"))), ), "tokenizer_model": item.get("tokenizer_model", payload.get("tokenizer_model")), "request_id": item.get("request_id", payload.get("request_id")), "idempotency_key": item.get( "idempotency_key", payload.get("idempotency_key"), ), } def _handle_batch(payload: dict[str, Any]) -> dict[str, Any]: if "input" in payload: raise HTTPException(status_code=400, detail="provide either input or inputs, not both") inputs = payload.get("inputs") if not isinstance(inputs, list): raise HTTPException(status_code=400, detail="inputs must be a list") if not inputs: raise HTTPException(status_code=400, detail="inputs list is empty") request_id = payload.get("request_id") if request_id is not None and not isinstance(request_id, str): raise HTTPException(status_code=400, detail="request_id must be a string") idempotency_key = payload.get("idempotency_key") if idempotency_key is not None and not isinstance(idempotency_key, str): raise HTTPException(status_code=400, detail="idempotency_key must be a string") results: list[dict[str, Any]] = [] totals = { "original_input_tokens": 0, "output_tokens": 0, "tokens_saved": 0, } succeeded = 0 failed = 0 for index, item in enumerate(inputs): item_result: dict[str, Any] = {"index": index} if isinstance(item, dict) and isinstance(item.get("id"), str): item_result["id"] = item["id"] try: item_id, item_payload = _merge_batch_item_payload(payload, item, index) if item_id is not None: item_result["id"] = item_id result = _compress_text(item_payload) except HTTPException as exc: item_result.update({"status": "error", "error": str(exc.detail)}) failed += 1 results.append(item_result) continue except ValueError as exc: item_result.update({"status": "error", "error": str(exc)}) failed += 1 results.append(item_result) continue item_result.update({"status": "ok", **result}) totals["original_input_tokens"] += int(result["original_input_tokens"]) totals["output_tokens"] += int(result["output_tokens"]) totals["tokens_saved"] += int(result["tokens_saved"]) succeeded += 1 results.append(item_result) compression_pct = ( round(100.0 * totals["tokens_saved"] / totals["original_input_tokens"], 1) if totals["original_input_tokens"] else 0.0 ) receipt_ids = [ result["receipt"]["receipt_id"] for result in results if result.get("status") == "ok" and result.get("receipt", {}).get("receipt_id") ] body = { "schema_version": API_SCHEMA_VERSION, "status": "ok" if failed == 0 else "partial_error", "endpoint": "/v1/compress", "maturity": "measurement_only", "input_count": len(inputs), "succeeded": succeeded, "failed": failed, **totals, "compression_percentage": compression_pct, "receipt_ids": receipt_ids, "results": results, } if request_id is not None: body["request_id"] = request_id if idempotency_key is not None: body["idempotency_key"] = idempotency_key return body def _handle_messages(payload: dict[str, Any]) -> dict[str, Any]: if "input" in payload or "inputs" in payload: raise HTTPException(status_code=400, detail="provide either messages, input, or inputs") messages = payload.get("messages") if not isinstance(messages, list) or not messages: raise HTTPException(status_code=400, detail="messages must be a non-empty list") if not all(isinstance(message, dict) for message in messages): raise HTTPException(status_code=400, detail="messages entries must be objects") settings = payload.get("compression_settings") or {} if not isinstance(settings, dict): raise HTTPException(status_code=400, detail="compression_settings must be an object") protected_roles = _string_set( settings.get("protected_roles"), default=DEFAULT_PROTECTED_MESSAGE_ROLES, field_name="compression_settings.protected_roles", ) compress_roles = ( _string_set( settings.get("compress_roles"), field_name="compression_settings.compress_roles", ) if "compress_roles" in settings else None ) protected_values = payload.get("protected_spans") or [] if not isinstance(protected_values, list) or not all( isinstance(value, str) for value in protected_values ): raise HTTPException(status_code=400, detail="protected_spans must be strings") request_id = payload.get("request_id") if request_id is not None and not isinstance(request_id, str): raise HTTPException(status_code=400, detail="request_id must be a string") idempotency_key = payload.get("idempotency_key") if idempotency_key is not None and not isinstance(idempotency_key, str): raise HTTPException(status_code=400, detail="idempotency_key must be a string") tokenizer_model = _optional_string(payload, "tokenizer_model") output_messages: list[dict[str, Any]] = [] receipts: list[dict[str, Any]] = [] nested_receipts: list[dict[str, Any]] = [] receipt_ids: list[str] = [] original_tokens = 0 output_tokens = 0 compressed_message_count = 0 skipped_message_count = 0 for index, message in enumerate(messages): role = _message_role(message, index) content = message.get("content") output_message = dict(message) if not isinstance(content, str): skipped_message_count += 1 receipts.append({ "index": index, "role": role, "status": "skipped", "reason": "non_string_content", "content_type": type(content).__name__, "original_input_tokens": 0, "output_tokens": 0, "tokens_saved": 0, }) output_messages.append(output_message) continue role_protected = role in protected_roles or ( compress_roles is not None and role not in compress_roles ) item_protected = list(protected_values) if role_protected and content: item_protected.append(content) result = _compress_text({ "input": content, "compression_settings": settings, "protected_spans": item_protected, "tool_schemas": payload.get("tool_schemas", payload.get("tools")), "tokenizer_model": tokenizer_model, "request_id": request_id, "idempotency_key": idempotency_key, }) output_message["content"] = result["output"] output_messages.append(output_message) original_tokens += int(result["original_input_tokens"]) output_tokens += int(result["output_tokens"]) tokens_saved = int(result["tokens_saved"]) if tokens_saved > 0: compressed_message_count += 1 receipt = result["receipt"] nested_receipts.append(receipt) receipt_ids.append(receipt["receipt_id"]) receipts.append({ "index": index, "role": role, "status": "ok", "protected_by_role": role_protected, "original_input_tokens": result["original_input_tokens"], "output_tokens": result["output_tokens"], "tokens_saved": tokens_saved, "compression_percentage": result["compression_percentage"], "receipt_id": receipt["receipt_id"], "receipt": receipt, }) tokens_saved_total = max(0, original_tokens - output_tokens) compression_pct = ( round(100.0 * tokens_saved_total / original_tokens, 1) if original_tokens else 0.0 ) decision = _message_decision( tokens_saved=tokens_saved_total, receipts=nested_receipts, ) aggregate_receipt = { "receipt_version": "message-compression-receipt-v0.1.0", "receipt_id": _aggregate_receipt_id({ "input_sha256": _stable_json_sha256(messages), "output_sha256": _stable_json_sha256(output_messages), "receipt_ids": receipt_ids, "tokens_saved": tokens_saved_total, "compression_percentage": compression_pct, "decision": decision, }), "request_id": request_id, "idempotency_key": idempotency_key, "message_count": len(messages), "compressed_message_count": compressed_message_count, "skipped_message_count": skipped_message_count, "protected_roles": sorted(protected_roles), "compress_roles": sorted(compress_roles) if compress_roles is not None else None, "decision": decision, "deletion_only": all( receipt.get("deletion_only", True) for receipt in nested_receipts ), "deterministic": True, "input_sha256": _stable_json_sha256(messages), "output_sha256": _stable_json_sha256(output_messages), "message_receipt_ids": receipt_ids, } compressed_prompt: dict[str, Any] = { "messages": output_messages, "protected_spans": protected_values, } if "tools" in payload: compressed_prompt["tools"] = payload["tools"] if "tool_schemas" in payload: compressed_prompt["tool_schemas"] = payload["tool_schemas"] body = { "schema_version": API_SCHEMA_VERSION, "status": "ok", "endpoint": "/v1/compress", "maturity": "measurement_only", "messages": output_messages, "compressed_messages": output_messages, "compressed_prompts": [compressed_prompt], "message_count": len(messages), "compressed_message_count": compressed_message_count, "skipped_message_count": skipped_message_count, "original_input_tokens": original_tokens, "output_tokens": output_tokens, "tokens_saved": tokens_saved_total, "compression_percentage": compression_pct, "receipt_ids": receipt_ids, "receipts": receipts, "receipt": aggregate_receipt, } if request_id is not None: body["request_id"] = request_id if idempotency_key is not None: body["idempotency_key"] = idempotency_key return body def _dispatch_compress(payload: dict[str, Any]) -> dict[str, Any]: if "inputs" in payload: return _handle_batch(payload) if "messages" in payload: return _handle_messages(payload) return _compress_text(payload) def _tokens(text: str) -> list[dict[str, Any]]: started = time.perf_counter() try: tokenizer = _get_tokenizer() encoded = tokenizer( text, add_special_tokens=False, return_offsets_mapping=True, ) pieces = tokenizer.convert_ids_to_tokens(encoded["input_ids"]) offsets = encoded["offset_mapping"] token_source = "deberta_tokenizer" except Exception: matches = list(re.finditer(r"\S+|\s+", text)) pieces = [match.group(0) for match in matches] offsets = [(match.start(), match.end()) for match in matches] token_source = "regex_fallback" elapsed = round((time.perf_counter() - started) * 1000.0, 3) return [ { "token": piece, "label": "KEEP", "score": 1.0, "start": int(offsets[i][0]), "end": int(offsets[i][1]), "source": token_source, "classifier_latency_ms": elapsed if i == 0 else 0.0, } for i, piece in enumerate(pieces) ] @app.get("/health") def health() -> dict[str, Any]: classifier_status, _, classifier_error = _manifest_drop_labels("") return { "status": "ok", "classifier_model": CLASSIFIER_MODEL, "classifier_status": classifier_status, "classifier_artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR), "classifier_error": classifier_error, "phase": "rules_api_first", "transport": { "request_id_header": REQUEST_ID_HEADER, "idempotency_key_header": IDEMPOTENCY_KEY_HEADER, "content_encoding_fallback_header": TOUCHDOWN_CONTENT_ENCODING_HEADER, "gzip_request_bodies": True, "gzip_responses": True, "idempotency_mode": "in_memory_response_replay", "idempotency_ttl_seconds": _idempotency_ttl_seconds(), }, } @app.post("/v1/classify") def classify(payload: dict[str, Any]) -> dict[str, Any]: text = payload.get("input") if not isinstance(text, str): raise HTTPException(status_code=400, detail="input must be a string") classifier_status, artifact_labels, classifier_error = _manifest_drop_labels(text) labels = artifact_labels or _tokens(text) return { "model": CLASSIFIER_MODEL, "task": "token_classification_keep_drop", "status": ( classifier_status if CLASSIFIER_ARTIFACT_DIR or artifact_labels or classifier_error else "tokenizer_only_until_trained_head" ), "artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR), "error": classifier_error, "labels": labels, "note": ( "This free Space scaffold loads the DeBERTa tokenizer and returns " "KEEP-only labels until a trained KEEP/DROP head or ONNX export is " "mounted. Do not treat it as a production compressor." ), } @app.post("/v1/compress") async def compress(request: Request) -> dict[str, Any]: raw = await request.body() try: body, _ = _decode_request_body( raw, request.headers.get(CONTENT_ENCODING_HEADER) or request.headers.get(TOUCHDOWN_CONTENT_ENCODING_HEADER), ) except UnsupportedContentEncoding as exc: raise HTTPException(status_code=415, detail=str(exc)) from exc except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc try: payload = json.loads(body.decode("utf-8")) if body else {} except Exception as exc: raise HTTPException(status_code=400, detail="invalid JSON body") from exc if not isinstance(payload, dict): raise HTTPException(status_code=400, detail="request body must be a JSON object") payload = _correlation_payload( payload, request_id=getattr(request.state, "request_id", None), idempotency_key=request.headers.get(IDEMPOTENCY_KEY_HEADER), ) return _handle_compress_with_idempotency(payload)