| from __future__ import annotations |
|
|
| import copy |
| import gzip |
| import hashlib |
| import json |
| import math |
| import os |
| import re |
| import threading |
| import time |
| import uuid |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Any |
|
|
| from fastapi import FastAPI, HTTPException, Request, Response |
|
|
# Hugging Face model id whose tokenizer backs the classifier path.
CLASSIFIER_MODEL = "microsoft/deberta-v3-small"
# Optional directory holding a trained classifier artifact (manifest, tokenizer, ONNX model).
CLASSIFIER_ARTIFACT_DIR = os.environ.get("TOUCHDOWN_CLASSIFIER_ARTIFACT_DIR")
API_SCHEMA_VERSION = "0.1.0"
RULES_VERSION = "hf-space-rules-v0.1.0"

# HTTP header names used by the transport layer.
REQUEST_ID_HEADER = "X-Request-ID"
IDEMPOTENCY_KEY_HEADER = "Idempotency-Key"
CONTENT_ENCODING_HEADER = "Content-Encoding"
TOUCHDOWN_CONTENT_ENCODING_HEADER = "X-Touchdown-Content-Encoding"
ACCEPT_ENCODING_HEADER = "Accept-Encoding"

GZIP_ENCODING = "gzip"
# The two leading bytes (ID1, ID2) of every gzip stream.
GZIP_MAGIC = bytes((0x1F, 0x8B))

# Idempotent responses are replayable for one day unless overridden via the env var.
DEFAULT_IDEMPOTENCY_TTL_SECONDS = 86_400
IDEMPOTENCY_TTL_ENV = "TOUCHDOWN_IDEMPOTENCY_TTL_SECONDS"
DEFAULT_PROTECTED_MESSAGE_ROLES = ("system", "developer")

# Case-insensitive filler phrases that the line rules may drop.
LOW_SIGNAL_PATTERNS = [
    re.compile(source, re.IGNORECASE)
    for source in (
        r"\blet me know if you (need|want|have)\b",
        r"\bi hope this helps\b",
        r"\bas an ai language model\b",
        r"\bof course[,!]?\s*$",
        r"\bcertainly[,!]?\s*$",
        r"\bsure[,!]?\s*$",
    )
]
|
|
# The FastAPI application object; the HTTP middleware in this module registers on it.
app = FastAPI(
    version="0.1.0",
    title="Touchdown Compression Classifier",
)
|
|
|
|
@app.middleware("http")
async def transport_middleware(request: Request, call_next):
    """Normalize transport encoding around every HTTP request.

    Decodes the request body (declared gzip via ``Content-Encoding`` or
    ``X-Touchdown-Content-Encoding``, or sniffed from the gzip magic bytes) and
    forwards a rebuilt request carrying the decoded bytes. The encoding headers
    are removed and ``content-length`` is rewritten. The response is tagged
    with ``X-Request-ID`` and gzipped when appropriate.

    An undecodable body short-circuits with a JSON error response: 415 for an
    unsupported encoding, 400 for a corrupt gzip body.
    """
    request_id = request.headers.get(REQUEST_ID_HEADER) or f"tdreq_{uuid.uuid4().hex}"
    declared_encoding = request.headers.get(CONTENT_ENCODING_HEADER) or request.headers.get(
        TOUCHDOWN_CONTENT_ENCODING_HEADER
    )
    try:
        decoded_body, request_was_gzip = _decode_request_body(
            await request.body(),
            declared_encoding,
        )
    except ValueError as exc:
        # UnsupportedContentEncoding subclasses ValueError, so one handler covers both cases.
        error_status = 415 if isinstance(exc, UnsupportedContentEncoding) else 400
        error_response = Response(
            content=json.dumps({
                "status": "error",
                "error": str(exc),
                "request_id": request_id,
            }),
            status_code=error_status,
            media_type="application/json",
            headers={REQUEST_ID_HEADER: request_id},
        )
        return await _gzip_response_if_needed(
            request,
            error_response,
            request_was_gzip=False,
        )

    delivered = False

    async def replay_body() -> dict[str, Any]:
        # Hand the decoded body to the app once, then report an empty final chunk.
        nonlocal delivered
        chunk = b"" if delivered else decoded_body
        delivered = True
        return {"type": "http.request", "body": chunk, "more_body": False}

    dropped_headers = (b"content-encoding", b"x-touchdown-content-encoding")
    decoded_length = str(len(decoded_body)).encode("latin-1")
    forwarded_headers = []
    for name, value in request.scope.get("headers", []):
        lowered = name.lower()
        if lowered in dropped_headers:
            continue
        forwarded_headers.append(
            (name, decoded_length if lowered == b"content-length" else value)
        )

    forwarded = Request({**request.scope, "headers": forwarded_headers}, replay_body)
    forwarded.state.request_id = request_id
    response = await call_next(forwarded)
    response.headers[REQUEST_ID_HEADER] = request_id
    return await _gzip_response_if_needed(
        forwarded,
        response,
        request_was_gzip=request_was_gzip,
    )
|
|
|
|
class UnsupportedContentEncoding(ValueError):
    """Signal that a request declared a content encoding this service cannot decode.

    Being a ``ValueError`` subclass, it is also caught by ``except ValueError``.
    """
|
|
|
|
| _IDEMPOTENCY_LOCK = threading.Lock() |
| _IDEMPOTENCY_CACHE: dict[tuple[str, str, str], dict[str, Any]] = {} |
|
|
|
|
| def _content_encoding(value: str | None) -> str: |
| return (value or "identity").split(";", 1)[0].strip().lower() or "identity" |
|
|
|
|
def _decode_request_body(raw: bytes, content_encoding: str | None) -> tuple[bytes, bool]:
    """Decode an HTTP request body according to its declared content encoding.

    Args:
        raw: Body bytes exactly as received.
        content_encoding: Value of the Content-Encoding (or Touchdown alias)
            header, if any.

    Returns:
        ``(body, was_gzip)``: the decoded bytes and whether gzip decoding was
        applied. A body with no declared encoding is still gunzipped when it
        starts with the gzip magic bytes.

    Raises:
        UnsupportedContentEncoding: for a declared encoding other than
            identity or gzip.
        ValueError: when the body is (declared or sniffed as) gzip but cannot
            be decompressed.
    """
    import zlib  # corrupt deflate payloads surface as zlib.error from gzip.decompress

    encoding = _content_encoding(content_encoding)
    if encoding == "identity":
        if not raw.startswith(GZIP_MAGIC):
            return raw, False
    elif encoding != GZIP_ENCODING:
        raise UnsupportedContentEncoding(
            f"unsupported content encoding: {content_encoding}"
        )
    try:
        return gzip.decompress(raw), True
    except (OSError, EOFError, zlib.error) as exc:
        # gzip.BadGzipFile (an OSError) covers bad headers and CRC mismatches.
        # EOFError covers truncated streams; zlib.error covers corrupt deflate data.
        # All of them are client errors, not server faults.
        raise ValueError("invalid gzip request body") from exc
|
|
|
|
def _accepts_gzip(accept_encoding: str | None) -> bool:
    """Return True when an ``Accept-Encoding`` header value permits gzip.

    Each comma-separated entry is ``coding[;q=value]``. A coding is refused
    only when its quality value is zero (``q=0``, ``q=0.0``, ...). Any other
    value, including fractional ones such as ``q=0.5``, accepts it.

    An explicit ``gzip`` entry takes precedence over the ``*`` wildcard, so
    ``"gzip;q=0, *"`` refuses gzip. Unparseable quality values are ignored,
    which leaves the default weight of 1. A missing or empty header accepts
    nothing.
    """
    explicit_q: float | None = None
    wildcard_q: float | None = None
    for entry in (accept_encoding or "").split(","):
        coding, *params = (piece.strip() for piece in entry.split(";"))
        coding = coding.lower()
        if coding not in (GZIP_ENCODING, "*"):
            continue
        quality = 1.0
        for param in params:
            name, _, value = param.partition("=")
            if name.strip().lower() != "q":
                continue
            try:
                quality = float(value.strip())
            except ValueError:
                pass  # malformed qvalue: keep the default weight
        if coding == GZIP_ENCODING:
            explicit_q = quality if explicit_q is None else max(explicit_q, quality)
        else:
            wildcard_q = quality if wildcard_q is None else max(wildcard_q, quality)
    effective_q = explicit_q if explicit_q is not None else wildcard_q
    return effective_q is not None and effective_q > 0
|
|
|
|
| def _should_gzip_response( |
| *, |
| accept_encoding: str | None, |
| request_was_gzip: bool, |
| ) -> bool: |
| return request_was_gzip or _accepts_gzip(accept_encoding) |
|
|
|
|
def _idempotency_ttl_seconds() -> int:
    """Return the idempotency replay window in seconds.

    The value comes from the ``IDEMPOTENCY_TTL_ENV`` environment variable. An
    unset or non-integer value falls back to ``DEFAULT_IDEMPOTENCY_TTL_SECONDS``.
    Negative values are clamped to 0.
    """
    configured = os.environ.get(IDEMPOTENCY_TTL_ENV)
    try:
        seconds = (
            DEFAULT_IDEMPOTENCY_TTL_SECONDS if configured is None else int(configured)
        )
    except ValueError:
        seconds = DEFAULT_IDEMPOTENCY_TTL_SECONDS
    return seconds if seconds > 0 else 0
|
|
|
|
| def _idempotency_key_from_payload(payload: dict[str, Any]) -> str | None: |
| value = payload.get("idempotency_key") |
| return value if isinstance(value, str) and value else None |
|
|
|
|
| def _fingerprint_value(value: Any) -> Any: |
| if isinstance(value, dict): |
| return { |
| key: _fingerprint_value(item) |
| for key, item in sorted(value.items()) |
| if key not in {"request_id", "idempotency_key"} |
| } |
| if isinstance(value, list): |
| return [_fingerprint_value(item) for item in value] |
| return value |
|
|
|
|
def _idempotency_fingerprint(route: str, payload: dict[str, Any]) -> str:
    """Hash a route plus its canonicalized payload into a SHA-256 hex digest.

    The payload goes through ``_fingerprint_value`` and is serialized as
    compact, key-sorted JSON with non-ASCII characters kept verbatim.
    Payloads that differ only in key order, or in ``request_id`` /
    ``idempotency_key`` fields, therefore share a fingerprint.
    """
    document = {"payload": _fingerprint_value(payload), "route": route}
    canonical_json = json.dumps(
        document,
        separators=(",", ":"),
        sort_keys=True,
        ensure_ascii=False,
    )
    digest = hashlib.sha256(canonical_json.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
| def _walk_receipts(value: Any) -> list[dict[str, Any]]: |
| receipts: list[dict[str, Any]] = [] |
| if isinstance(value, dict): |
| receipt = value.get("receipt") |
| if isinstance(receipt, dict): |
| receipts.append(receipt) |
| for item in value.values(): |
| receipts.extend(_walk_receipts(item)) |
| elif isinstance(value, list): |
| for item in value: |
| receipts.extend(_walk_receipts(item)) |
| return receipts |
|
|
|
|
def _mark_idempotency(
    body: dict[str, Any],
    *,
    key: str,
    fingerprint: str,
    replayed: bool,
    original_request_id: str | None,
) -> dict[str, Any]:
    """Return a deep copy of *body* annotated with idempotency metadata.

    The same fields are written at the top level and into every nested
    receipt whose ``idempotency_key`` is unset or already equals *key*.
    Receipts bound to a different key are left untouched. The
    ``idempotency_original_request_id`` field is only written when
    *original_request_id* is truthy. *body* itself is not mutated.
    """
    annotations: dict[str, Any] = {
        "idempotency_key": key,
        "idempotency_mode": "in_memory_response_replay",
        "idempotency_replayed": replayed,
        "idempotency_fingerprint": fingerprint,
        "idempotency_ttl_seconds": _idempotency_ttl_seconds(),
    }
    if original_request_id:
        annotations["idempotency_original_request_id"] = original_request_id
    marked = copy.deepcopy(body)
    marked.update(annotations)
    for receipt in _walk_receipts(marked):
        if receipt.get("idempotency_key") in (None, key):
            receipt.update(annotations)
    return marked
|
|
|
|
def _cached_idempotency_body(
    *,
    route: str,
    key: str,
    fingerprint: str,
) -> dict[str, Any] | None:
    """Look up a cached response for ``(route, key)`` and prepare it for replay.

    Expired cache entries are purged first, under the cache lock. A miss
    returns None. A hit with the same payload fingerprint returns a deep copy
    of the cached body marked as replayed.

    Raises:
        HTTPException: 409 when *key* was already used on *route* with a
            different payload fingerprint.
    """
    now = time.monotonic()
    with _IDEMPOTENCY_LOCK:
        stale_keys = [
            cache_key
            for cache_key, record in _IDEMPOTENCY_CACHE.items()
            if float(record["expires_at"]) <= now
        ]
        for cache_key in stale_keys:
            del _IDEMPOTENCY_CACHE[cache_key]
        hit = next(
            (
                record
                for cache_key, record in _IDEMPOTENCY_CACHE.items()
                if cache_key[:2] == (route, key)
            ),
            None,
        )
        if hit is None:
            return None
        if hit["fingerprint"] != fingerprint:
            raise HTTPException(status_code=409, detail={
                "schema_version": API_SCHEMA_VERSION,
                "status": "error",
                "error": "idempotency_key_reused_with_different_payload",
                "idempotency_key": key,
                "idempotency_fingerprint": fingerprint,
                "idempotency_existing_fingerprint": hit["fingerprint"],
            })
        return _mark_idempotency(
            hit["body"],
            key=key,
            fingerprint=fingerprint,
            replayed=True,
            original_request_id=hit.get("request_id"),
        )
|
|
|
|
def _store_idempotency_body(
    *,
    route: str,
    key: str,
    fingerprint: str,
    body: dict[str, Any],
    request_id: str | None,
) -> dict[str, Any]:
    """Mark a freshly computed response and cache it for later replay.

    The marked copy is cached under ``(route, key, fingerprint)`` only when the
    configured TTL is positive. The returned dict is the same object that is
    cached.
    """
    marked = _mark_idempotency(
        body,
        key=key,
        fingerprint=fingerprint,
        replayed=False,
        original_request_id=request_id,
    )
    ttl_seconds = _idempotency_ttl_seconds()
    if ttl_seconds <= 0:
        return marked
    with _IDEMPOTENCY_LOCK:
        _IDEMPOTENCY_CACHE[(route, key, fingerprint)] = {
            "fingerprint": fingerprint,
            "body": marked,
            "request_id": request_id,
            "expires_at": time.monotonic() + ttl_seconds,
        }
    return marked
|
|
|
|
def _handle_compress_with_idempotency(payload: dict[str, Any]) -> dict[str, Any]:
    """Run the ``/v1/compress`` dispatch with optional in-memory response replay.

    Without a usable ``idempotency_key`` the payload is dispatched directly.
    With one, a cached response for the same key and payload fingerprint is
    replayed, and reusing the key with another payload raises 409. Otherwise
    the fresh response is stored, tagged with the payload's ``request_id``
    when that is a string.
    """
    idempotency_key = _idempotency_key_from_payload(payload)
    if not idempotency_key:
        return _dispatch_compress(payload)
    route = "/v1/compress"
    fingerprint = _idempotency_fingerprint(route, payload)
    replay = _cached_idempotency_body(
        route=route,
        key=idempotency_key,
        fingerprint=fingerprint,
    )
    if replay is not None:
        return replay
    fresh_body = _dispatch_compress(payload)
    raw_request_id = payload.get("request_id")
    return _store_idempotency_body(
        route=route,
        key=idempotency_key,
        fingerprint=fingerprint,
        body=fresh_body,
        request_id=raw_request_id if isinstance(raw_request_id, str) else None,
    )
|
|
|
|
async def _gzip_response_if_needed(
    request: Request,
    response: Response,
    *,
    request_was_gzip: bool,
) -> Response:
    """Return *response* gzip-compressed when the client should receive gzip.

    Compression happens when the client sent a gzip body or advertises gzip
    in ``Accept-Encoding`` (see ``_should_gzip_response``). Some responses are
    returned unchanged:

    - responses that already carry a ``Content-Encoding``;
    - responses whose status forbids a body (1xx, 204, 304).

    Streaming bodies are fully buffered before compression.

    The rebuilt response keeps every original header pair, including repeated
    headers such as ``Set-Cookie``. ``Content-Length`` and ``Content-Encoding``
    are recomputed, and ``Accept-Encoding`` is merged into any existing
    ``Vary`` value.
    """
    should_gzip = _should_gzip_response(
        accept_encoding=request.headers.get(ACCEPT_ENCODING_HEADER),
        request_was_gzip=request_was_gzip,
    )
    if not should_gzip or response.headers.get(CONTENT_ENCODING_HEADER):
        return response
    status_code = response.status_code
    if status_code < 200 or status_code in (204, 304):
        # These statuses must not carry a body, so attaching a gzip payload
        # (never empty, even for empty input) would produce an invalid response.
        return response

    if hasattr(response, "body_iterator"):
        charset = getattr(response, "charset", "utf-8")
        chunks: list[bytes] = []
        async for chunk in response.body_iterator:
            chunks.append(chunk.encode(charset) if isinstance(chunk, str) else bytes(chunk))
        body = b"".join(chunks)
    else:
        body = getattr(response, "body", b"")
    compressed = gzip.compress(body)

    compressed_response = Response(
        content=compressed,
        status_code=status_code,
        background=getattr(response, "background", None),
    )
    # Copy raw header pairs rather than a dict so repeated headers survive.
    compressed_response.raw_headers = [
        (name, value)
        for name, value in response.raw_headers
        if name.lower() not in (b"content-length", b"content-encoding")
    ]
    headers = compressed_response.headers
    headers[CONTENT_ENCODING_HEADER] = GZIP_ENCODING
    headers["Content-Length"] = str(len(compressed))
    existing_vary = headers.get("vary")
    if existing_vary is None:
        headers["Vary"] = ACCEPT_ENCODING_HEADER
    else:
        vary_tokens = {token.strip().lower() for token in existing_vary.split(",")}
        if "*" not in vary_tokens and ACCEPT_ENCODING_HEADER.lower() not in vary_tokens:
            headers["Vary"] = f"{existing_vary}, {ACCEPT_ENCODING_HEADER}"
    return compressed_response
|
|
|
|
@lru_cache(maxsize=1)
def _get_tokenizer():
    """Load once and return the tokenizer for ``CLASSIFIER_MODEL``.

    ``transformers`` is imported lazily inside the function.
    """
    import transformers

    return transformers.AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
|
|
|
|
@lru_cache(maxsize=8)
def _get_count_tokenizer(model_name: str):
    """Load the Hugging Face tokenizer named *model_name* for token counting.

    Up to eight tokenizers are kept in an LRU cache. ``transformers`` is
    imported lazily inside the function.
    """
    import transformers

    return transformers.AutoTokenizer.from_pretrained(model_name)
|
|
|
|
@lru_cache(maxsize=1)
def _classifier_manifest() -> dict[str, Any] | None:
    """Load ``classifier_manifest.json`` from the configured artifact directory.

    Returns:
        - None when no artifact directory is configured;
        - the parsed manifest when it is a JSON object;
        - otherwise a ``{"status": "manifest_error", "error": ...}`` dict.

    The result is cached for the life of the process.
    """
    if not CLASSIFIER_ARTIFACT_DIR:
        return None
    manifest_path = Path(CLASSIFIER_ARTIFACT_DIR).joinpath("classifier_manifest.json")
    if not manifest_path.exists():
        return {
            "status": "manifest_error",
            "error": f"missing classifier manifest: {manifest_path}",
        }
    try:
        parsed = json.loads(manifest_path.read_text(encoding="utf-8"))
    except Exception as exc:
        return {"status": "manifest_error", "error": str(exc)}
    if isinstance(parsed, dict):
        return parsed
    return {"status": "manifest_error", "error": "manifest must be an object"}
|
|
|
|
@lru_cache(maxsize=1)
def _get_artifact_tokenizer():
    """Load once and return the tokenizer saved in ``CLASSIFIER_ARTIFACT_DIR``."""
    import transformers

    # NOTE(review): an unset directory becomes the literal path "None" here.
    # In this module the only caller runs when an artifact manifest was found.
    artifact_dir = str(CLASSIFIER_ARTIFACT_DIR)
    return transformers.AutoTokenizer.from_pretrained(artifact_dir)
|
|
|
|
@lru_cache(maxsize=1)
def _get_onnx_session(model_path: str):
    """Return a CPU-only ONNX Runtime session for *model_path*.

    Only the session for the most recently requested path is cached.
    """
    from onnxruntime import InferenceSession

    return InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
|
|
|
| def _find_spans(text: str, needle: str) -> list[tuple[int, int]]: |
| spans = [] |
| cursor = 0 |
| while needle: |
| index = text.find(needle, cursor) |
| if index < 0: |
| return spans |
| spans.append((index, index + len(needle))) |
| cursor = index + len(needle) |
| return spans |
|
|
|
|
| def _regex_spans(text: str, pattern: str, flags: int) -> list[tuple[int, int]]: |
| return [(match.start(), match.end()) for match in re.finditer(pattern, text, flags)] |
|
|
|
|
| def _json_line_spans(text: str) -> list[tuple[int, int]]: |
| spans = [] |
| cursor = 0 |
| for line in text.splitlines(keepends=True): |
| stripped = line.strip() |
| if stripped and stripped[0] in "[{": |
| try: |
| import json |
|
|
| json.loads(stripped) |
| except Exception: |
| pass |
| else: |
| offset = line.index(stripped[0]) |
| spans.append((cursor + offset, cursor + offset + len(stripped))) |
| cursor += len(line) |
| return spans |
|
|
|
|
| def _merge(spans: list[tuple[int, int]]) -> list[tuple[int, int]]: |
| if not spans: |
| return [] |
| ordered = sorted(spans) |
| merged = [ordered[0]] |
| for start, end in ordered[1:]: |
| prev_start, prev_end = merged[-1] |
| if start <= prev_end: |
| merged[-1] = (prev_start, max(prev_end, end)) |
| else: |
| merged.append((start, end)) |
| return merged |
|
|
|
|
| def _overlaps(spans: list[tuple[int, int]], start: int, end: int) -> bool: |
| return any(start < span_end and end > span_start for span_start, span_end in spans) |
|
|
|
|
def _manifest_drop_labels(text: str) -> tuple[str, list[dict[str, Any]], str | None]:
    """Produce classifier DROP labels for *text* from the configured artifact.

    Returns ``(status, labels, error)`` as follows:

    - no manifest: ``("tokenizer_loaded_no_trained_keep_drop_head", [], None)``;
    - manifest load error: ``("manifest_error", [], <message>)``;
    - ONNX model file present: token labels from ``_onnx_labels``, or
      ``("local_artifact_error", [], <message>)`` if inference raises;
    - otherwise: one DROP label per non-overlapping occurrence of each
      ``drop_phrases`` entry (a string, or a dict with ``text``/``score``).
    """
    manifest = _classifier_manifest()
    if not manifest:
        return "tokenizer_loaded_no_trained_keep_drop_head", [], None
    if manifest.get("status") == "manifest_error":
        return "manifest_error", [], str(manifest.get("error"))

    model_file = Path(CLASSIFIER_ARTIFACT_DIR or "").joinpath(
        str(manifest.get("onnx_model", "model.onnx"))
    )
    if model_file.exists():
        try:
            token_labels = _onnx_labels(text, manifest, model_file)
        except Exception as exc:
            return "local_artifact_error", [], str(exc)
        return str(manifest.get("status") or "trained_keep_drop_head"), token_labels, None

    labels: list[dict[str, Any]] = []
    for index, entry in enumerate(manifest.get("drop_phrases") or []):
        if isinstance(entry, str):
            phrase, score = entry, 1.0
        elif isinstance(entry, dict):
            phrase = str(entry.get("text") or "")
            score = float(entry.get("score", 1.0))
        else:
            continue
        labels.extend(
            {
                "token": phrase,
                "label": "DROP",
                "score": score,
                "start": span_start,
                "end": span_end,
                "source": "manifest_drop_phrase",
                "classifier_index": index,
            }
            for span_start, span_end in _find_spans(text, phrase)
        )
    return str(manifest.get("status") or "artifact_manifest_loaded"), labels, None
|
|
|
|
| def _softmax(values: list[float]) -> list[float]: |
| peak = max(values) |
| exps = [math.exp(value - peak) for value in values] |
| total = sum(exps) or 1.0 |
| return [value / total for value in exps] |
|
|
|
|
def _onnx_labels(
    text: str,
    manifest: dict[str, Any],
    model_path: Path,
) -> list[dict[str, Any]]:
    """Label each token of *text* KEEP or DROP with the exported ONNX token classifier.

    *text* is tokenized by the artifact tokenizer into windows sized by the
    manifest's ``max_length`` and ``stride`` (defaults 512 and 0). Each token's
    logits are softmaxed. The token is labelled DROP when its DROP probability
    reaches ``drop_score_threshold`` (default 0.5), otherwise KEEP.

    Tokens with empty character offsets are skipped. When the same character
    span appears in several windows, the entry with the highest DROP
    probability wins.

    Returns:
        Label dicts sorted by ``(start, end)``.
    """
    import numpy as np

    tokenizer = _get_artifact_tokenizer()
    windows = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        return_tensors="np",
        truncation=True,
        max_length=int(manifest.get("max_length", 512)),
        stride=int(manifest.get("stride", 0)),
        return_overflowing_tokens=True,
        padding=True,
    )
    offset_windows = windows.pop("offset_mapping")
    window_ids = windows["input_ids"]
    session = _get_onnx_session(str(model_path))
    accepted_inputs = {spec.name for spec in session.get_inputs()}
    feeds = {name: array for name, array in windows.items() if name in accepted_inputs}
    logits = session.run(None, feeds)[0]

    raw_id2label = manifest.get("id2label") or {"0": "KEEP", "1": "DROP"}
    id2label = {str(label_key): label_name for label_key, label_name in raw_id2label.items()}
    label2id = {str(label_name).upper(): int(label_key) for label_key, label_name in id2label.items()}
    keep_id = label2id.get("KEEP", 0)
    drop_id = label2id.get("DROP", 1)
    threshold = float(manifest.get("drop_score_threshold", 0.5))

    strongest: dict[tuple[int, int], dict[str, Any]] = {}
    for window_index, window_logits in enumerate(logits):
        window_tokens = tokenizer.convert_ids_to_tokens(window_ids[window_index])
        window_offsets = offset_windows[window_index]
        for position, token_logits in enumerate(window_logits):
            probabilities = _softmax(np.asarray(token_logits, dtype=float).tolist())
            keep_score = float(probabilities[keep_id]) if keep_id < len(probabilities) else 0.0
            drop_score = float(probabilities[drop_id]) if drop_id < len(probabilities) else 0.0
            label_id = drop_id if drop_score >= threshold else keep_id
            span_start, span_end = (int(bound) for bound in window_offsets[position])
            if span_end <= span_start:
                # Padding tokens report empty offsets and carry no text.
                continue
            span = (span_start, span_end)
            candidate = {
                "token": window_tokens[position],
                "label": id2label.get(str(label_id), "KEEP"),
                "score": round(drop_score if label_id == drop_id else keep_score, 6),
                "keep_score": round(keep_score, 6),
                "drop_score": round(drop_score, 6),
                "drop_score_threshold": threshold,
                "start": span_start,
                "end": span_end,
                "source": "onnx_token_classifier",
                "chunk_index": window_index,
            }
            if drop_score > float(strongest.get(span, {}).get("drop_score", -1.0)):
                strongest[span] = candidate
    return [strongest[span] for span in sorted(strongest)]
|
|
|
|
def _safe_classifier_drop_ranges(
    labels: list[dict[str, Any]],
    protected: list[tuple[int, int]],
    *,
    text_length: int,
    min_score: float,
) -> tuple[list[tuple[int, int]], int, int, dict[str, int]]:
    """Turn classifier labels into merged character ranges that are safe to delete.

    A label is a drop candidate when it is labelled ``DROP`` (via ``label`` or
    ``entity``) or its ``drop_score`` reaches *min_score*. A candidate is then
    rejected, and counted by reason, when:

    - its offsets are missing or unparseable;
    - it falls outside ``[0, text_length]`` or is empty;
    - its score is below *min_score*;
    - it overlaps a *protected* span.

    NaN scores never satisfy the threshold, so malformed or non-finite
    classifier output cannot cause text to be deleted.

    Returns:
        ``(ranges, drop_candidates, applied_count, blocked_by_reason)``, where
        *ranges* is merged and *applied_count* equals ``len(ranges)``.
    """
    blocked_by_reason = {
        "malformed_or_missing_offsets": 0,
        "out_of_bounds": 0,
        "below_min_score": 0,
        "protected_span_overlap": 0,
    }
    ranges: list[tuple[int, int]] = []
    drop_labels = 0
    for item in labels:
        raw_drop_score = item.get("drop_score")
        try:
            drop_score = float(raw_drop_score) if raw_drop_score is not None else None
        except (TypeError, ValueError):
            drop_score = None
        labelled_drop = str(item.get("label") or item.get("entity") or "").upper() == "DROP"
        # `>=` is False for NaN, so a NaN score never makes a label a candidate.
        scored_drop = drop_score is not None and drop_score >= min_score
        if not labelled_drop and not scored_drop:
            continue
        drop_labels += 1
        try:
            start = int(item["start"])
            end = int(item["end"])
            score = drop_score if drop_score is not None else float(item.get("score", 1.0))
        except (KeyError, TypeError, ValueError, OverflowError):
            blocked_by_reason["malformed_or_missing_offsets"] += 1
            continue
        if start < 0 or end > text_length or end <= start:
            blocked_by_reason["out_of_bounds"] += 1
            continue
        # Written as `not >=` rather than `<` so NaN scores are blocked too.
        if not score >= min_score:
            blocked_by_reason["below_min_score"] += 1
            continue
        if _overlaps(protected, start, end):
            blocked_by_reason["protected_span_overlap"] += 1
            continue
        ranges.append((start, end))
    merged = _merge(ranges)
    return merged, drop_labels, len(merged), blocked_by_reason
|
|
|
|
| def _is_subsequence(candidate: str, original: str) -> bool: |
| cursor = 0 |
| for char in candidate: |
| cursor = original.find(char, cursor) |
| if cursor < 0: |
| return False |
| cursor += 1 |
| return True |
|
|
|
|
| def _sha256_text(value: str) -> str: |
| return hashlib.sha256(value.encode("utf-8")).hexdigest() |
|
|
|
|
| def _receipt_id(payload: dict[str, Any]) -> str: |
| encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")) |
| return "tdcr_" + hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:24] |
|
|
|
|
| def _stable_json_sha256(value: Any) -> str: |
| encoded = json.dumps( |
| value, |
| ensure_ascii=False, |
| sort_keys=True, |
| separators=(",", ":"), |
| ) |
| return hashlib.sha256(encoded.encode("utf-8")).hexdigest() |
|
|
|
|
| def _aggregate_receipt_id(payload: dict[str, Any]) -> str: |
| encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")) |
| return "tdcm_" + hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:24] |
|
|
|
|
| def _string_set( |
| value: Any, |
| *, |
| default: tuple[str, ...] = (), |
| field_name: str, |
| ) -> set[str]: |
| if value is None: |
| return set(default) |
| if ( |
| isinstance(value, list) |
| and all(isinstance(item, str) and item for item in value) |
| ): |
| return {item.lower() for item in value} |
| raise HTTPException(status_code=400, detail=f"{field_name} must be a list of strings") |
|
|
|
|
| def _message_role(message: dict[str, Any], index: int) -> str: |
| role = message.get("role") |
| if not isinstance(role, str) or not role: |
| raise HTTPException( |
| status_code=400, |
| detail=f"messages[{index}].role must be a string", |
| ) |
| return role.lower() |
|
|
|
|
| def _message_decision( |
| *, |
| tokens_saved: int, |
| receipts: list[dict[str, Any]], |
| ) -> str: |
| decisions = [ |
| receipt.get("decision") |
| for receipt in receipts |
| if isinstance(receipt, dict) |
| ] |
| if any(decision == "reject" for decision in decisions): |
| return "reject" |
| if any(decision == "needs_review" for decision in decisions): |
| return "needs_review" |
| if tokens_saved <= 0: |
| return "no_op" |
| return "high_confidence" |
|
|
|
|
| def _correlation_payload( |
| payload: dict[str, Any], |
| *, |
| request_id: str | None = None, |
| idempotency_key: str | None = None, |
| ) -> dict[str, Any]: |
| out = dict(payload) |
| if request_id and "request_id" not in out: |
| out["request_id"] = request_id |
| if idempotency_key and "idempotency_key" not in out: |
| out["idempotency_key"] = idempotency_key |
| return out |
|
|
|
|
| def _tool_schema_variants(schema: Any) -> list[str]: |
| if isinstance(schema, str): |
| return [schema] if schema else [] |
| if not isinstance(schema, (dict, list)): |
| raise HTTPException( |
| status_code=400, |
| detail="tool_schemas entries must be strings, objects, or arrays", |
| ) |
| variants: list[str] = [] |
| for kwargs in ( |
| {"ensure_ascii": False, "sort_keys": False, "separators": (",", ":")}, |
| {"ensure_ascii": False, "sort_keys": True, "separators": (",", ":")}, |
| {"ensure_ascii": False, "sort_keys": False, "indent": 2}, |
| {"ensure_ascii": False, "sort_keys": True, "indent": 2}, |
| ): |
| encoded = json.dumps(schema, **kwargs) |
| if encoded not in variants: |
| variants.append(encoded) |
| return variants |
|
|
|
|
| def _tool_schema_groups(tool_schemas: Any) -> list[list[str]]: |
| if tool_schemas is None: |
| return [] |
| if not isinstance(tool_schemas, list): |
| raise HTTPException(status_code=400, detail="tool_schemas must be a list") |
| return [_tool_schema_variants(schema) for schema in tool_schemas] |
|
|
|
|
| def _schema_occurrences(text: str, group: list[str]) -> int: |
| return sum(len(_find_spans(text, candidate)) for candidate in group) |
|
|
|
|
| def _tool_schema_missing_groups( |
| input_text: str, |
| output: str, |
| groups: list[list[str]], |
| ) -> list[list[str]]: |
| missing = [] |
| for group in groups: |
| input_occurrences = _schema_occurrences(input_text, group) |
| output_occurrences = _schema_occurrences(output, group) |
| if input_occurrences > 0 and output_occurrences == 0: |
| missing.append(group) |
| return missing |
|
|
|
|
| def _optional_string(payload: dict[str, Any], key: str) -> str | None: |
| value = payload.get(key) |
| if value is None: |
| return None |
| if not isinstance(value, str): |
| raise HTTPException(status_code=400, detail=f"{key} must be a string") |
| return value |
|
|
|
|
| def _count_tokens(text: str, tokenizer_model: str | None) -> tuple[int, bool, str]: |
| if tokenizer_model: |
| try: |
| tokenizer = _get_count_tokenizer(tokenizer_model) |
| return ( |
| len(tokenizer.encode(text, add_special_tokens=False)), |
| True, |
| "tokenizer", |
| ) |
| except Exception: |
| pass |
| return max(1, round(len(text) / 4.0)), False, "chars_per_token_estimate" |
|
|
|
|
def _protected_spans(
    text: str,
    protected_values: list[str],
    tool_schema_groups: list[list[str]],
    *,
    preserve_code: bool,
    preserve_json: bool,
) -> tuple[
    list[tuple[int, int]],
    list[tuple[int, int]],
    list[tuple[int, int]],
    list[tuple[int, int]],
    list[tuple[int, int]],
]:
    """Locate every span of *text* that compression must leave untouched.

    Returns ``(all_protected, code_spans, json_spans, system_spans,
    tool_schema_spans)``:

    - *all_protected* merges literal protected values, tool-schema
      renderings, and the code, JSON and system spans;
    - code spans are fenced ```` ``` ```` blocks, only when *preserve_code*;
    - JSON spans are fenced ``json`` blocks plus standalone JSON lines, only
      when *preserve_json*;
    - system spans (``<system>`` tags or ``system:`` lines) are always
      collected;
    - *tool_schema_spans* is returned merged.
    """
    value_spans = [span for value in protected_values for span in _find_spans(text, value)]
    tool_schema_spans = [
        span
        for group in tool_schema_groups
        for candidate in group
        for span in _find_spans(text, candidate)
    ]
    if preserve_code:
        code_spans = _regex_spans(text, r"```.*?```", re.MULTILINE | re.DOTALL)
    else:
        code_spans = []
    if preserve_json:
        json_spans = _regex_spans(
            text, r"```(?:json|JSON)\s*\n.*?```", re.MULTILINE | re.DOTALL
        ) + _json_line_spans(text)
    else:
        json_spans = []
    tagged_system = _regex_spans(text, r"<system\b[^>]*>.*?</system>", re.MULTILINE | re.DOTALL)
    prefixed_system = _regex_spans(text, r"^system\s*:\s*.*$", re.MULTILINE | re.IGNORECASE)
    system_spans = tagged_system + prefixed_system
    everything = value_spans + tool_schema_spans + code_spans + json_spans + system_spans
    return (
        _merge(everything),
        code_spans,
        json_spans,
        system_spans,
        _merge(tool_schema_spans),
    )
|
|
|
|
def _rule_based_line_drops(
    text: str,
    protected: list[tuple[int, int]],
    aggressiveness: float,
) -> list[tuple[int, int, str, str]]:
    """Return line deletions chosen by the deterministic rules, in text order.

    Each entry is ``(start, end, reason, preview)``. *start*/*end* cover the
    whole line including its line break, and *preview* is the first 120
    characters of the stripped line.

    Lines overlapping a *protected* span are never dropped, but they still
    count as seen for duplicate detection. Reasons, checked in this order:

    - ``extra_blank``: a blank line when the most recently kept line was
      also blank;
    - ``duplicate_line``: aggressiveness >= 0.10, and the line's
      whitespace-collapsed, lowercased text is longer than 30 characters and
      matches an earlier kept or protected line;
    - ``low_signal_phrase``: aggressiveness >= 0.20 and the line matches one
      of ``LOW_SIGNAL_PATTERNS``.
    """
    seen: set[str] = set()
    drops: list[tuple[int, int, str, str]] = []
    cursor = 0
    previous_blank_kept = False
    for line in text.splitlines(keepends=True):
        start = cursor
        end = cursor + len(line)
        cursor = end
        stripped = line.strip()
        normalized = re.sub(r"\s+", " ", stripped).lower()
        if _overlaps(protected, start, end):
            if normalized:
                seen.add(normalized)
            previous_blank_kept = False
            continue
        reason = None
        if not stripped and previous_blank_kept:
            reason = "extra_blank"
        elif len(normalized) > 30 and normalized in seen and aggressiveness >= 0.10:
            reason = "duplicate_line"
        elif aggressiveness >= 0.20 and any(
            pattern.search(stripped) for pattern in LOW_SIGNAL_PATTERNS
        ):
            reason = "low_signal_phrase"
        if reason:
            drops.append((start, end, reason, stripped[:120]))
            continue
        previous_blank_kept = not stripped
        if normalized:
            seen.add(normalized)
    return drops


def _consistent_token_counts(
    text: str,
    output: str,
    tokenizer_model: str | None,
) -> tuple[int, int, bool, str]:
    """Count tokens before and after compression with one consistent method.

    Returns ``(before, after, exact, method)``. If the tokenizer works on one
    side but falls back to the estimate on the other, both sides are
    recounted with the character estimate so the difference stays meaningful.
    """
    before, before_exact, before_method = _count_tokens(text, tokenizer_model)
    after, after_exact, after_method = _count_tokens(output, tokenizer_model)
    if before_method != after_method:
        before, before_exact, before_method = _count_tokens(text, None)
        after, after_exact, _ = _count_tokens(output, None)
    return before, after, before_exact and after_exact, before_method


def _compress_text(payload: dict[str, Any]) -> dict[str, Any]:
    """Compress ``payload["input"]`` by deleting text only, and describe the result.

    Deletions come from two sources:

    - line rules (extra blank lines, long duplicate lines, low-signal filler);
    - classifier DROP spans that pass ``_safe_classifier_drop_ranges``.

    Protected values, tool schemas, fenced code/JSON blocks, JSON lines and
    system-prompt spans are excluded from deletion. If any of them is still
    missing from the output, the decision is ``"reject"``.

    Returns:
        The ``/v1/compress`` response body: output text, token counts, and a
        receipt with hashes, safety checks and classifier diagnostics.

    Raises:
        HTTPException: 400 when a request field has the wrong type or range.
    """
    text = payload.get("input")
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="input must be a string")
    settings = payload.get("compression_settings") or {}
    if not isinstance(settings, dict):
        raise HTTPException(status_code=400, detail="compression_settings must be an object")
    request_id = payload.get("request_id")
    if request_id is not None and not isinstance(request_id, str):
        raise HTTPException(status_code=400, detail="request_id must be a string")
    idempotency_key = payload.get("idempotency_key")
    if idempotency_key is not None and not isinstance(idempotency_key, str):
        raise HTTPException(status_code=400, detail="idempotency_key must be a string")
    tokenizer_model = _optional_string(payload, "tokenizer_model")
    protected_values = payload.get("protected_spans") or []
    if not isinstance(protected_values, list) or not all(
        isinstance(value, str) for value in protected_values
    ):
        raise HTTPException(status_code=400, detail="protected_spans must be strings")
    tool_schema_groups = _tool_schema_groups(
        payload.get("tool_schemas", payload.get("tools"))
    )
    try:
        aggressiveness = float(settings.get("aggressiveness", 0.35))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="aggressiveness must be numeric") from exc
    if not 0.0 <= aggressiveness <= 1.0:
        raise HTTPException(status_code=400, detail="aggressiveness must be 0..1")
    try:
        classifier_min_drop_score = float(settings.get("classifier_min_drop_score", 0.5))
    except (TypeError, ValueError) as exc:
        # Previously an unparseable value surfaced as a 500; it is a client error.
        raise HTTPException(
            status_code=400,
            detail="classifier_min_drop_score must be numeric",
        ) from exc

    started = time.perf_counter()
    preserve_code = bool(settings.get("preserve_code", True))
    preserve_json = bool(settings.get("preserve_json", True))
    protected, code_spans, json_spans, system_spans, tool_schema_spans = _protected_spans(
        text,
        protected_values,
        tool_schema_groups,
        preserve_code=preserve_code,
        preserve_json=preserve_json,
    )
    classifier_status, classifier_labels, classifier_error = _manifest_drop_labels(text)
    classifier_ranges, classifier_drop_labels, classifier_applied, classifier_blocked_by_reason = (
        _safe_classifier_drop_ranges(
            classifier_labels,
            protected,
            text_length=len(text),
            min_score=classifier_min_drop_score,
        )
    )
    drops = _rule_based_line_drops(text, protected, aggressiveness)

    drop_ranges = _merge([(start, end) for start, end, _, _ in drops] + classifier_ranges)
    kept_chunks = []
    cursor = 0
    for start, end in drop_ranges:
        kept_chunks.append(text[cursor:start])
        cursor = end
    kept_chunks.append(text[cursor:])
    output = "".join(kept_chunks)

    before, after, token_count_exact, token_method = _consistent_token_counts(
        text,
        output,
        tokenizer_model,
    )
    saved = max(0, before - after)
    # A real tokenizer can count an empty/whitespace-only input as 0 tokens.
    compression_percentage = round(100.0 * saved / before, 1) if before > 0 else 0.0

    missing = [value for value in protected_values if value and value not in output]
    code_preserved = all(text[start:end] in output for start, end in code_spans)
    json_preserved = all(text[start:end] in output for start, end in json_spans)
    system_preserved = all(text[start:end] in output for start, end in system_spans)
    missing_tool_schemas = _tool_schema_missing_groups(
        text,
        output,
        tool_schema_groups,
    )
    tool_schemas_preserved = not missing_tool_schemas
    # NOTE(review): when text was removed and aggressiveness > 0.65 this reports
    # "no_op" even though output != input; confirm "needs_review" was not intended.
    decision = "reject" if (
        missing
        or not code_preserved
        or not json_preserved
        or not system_preserved
        or not tool_schemas_preserved
    ) else (
        "high_confidence" if saved > 0 and aggressiveness <= 0.65 else "no_op"
    )
    dropped_segments = [
        {"reason": reason, "preview": preview, "start": start, "end": end}
        for start, end, reason, preview in drops[:20]
    ]
    classifier_blocked = sum(classifier_blocked_by_reason.values())
    receipt = {
        "request_id": request_id,
        "idempotency_key": idempotency_key,
        "idempotency_mode": "correlation_only_no_response_replay",
        "protected_spans_checked": len(protected_values),
        "protected_spans_missing": len(missing),
        "code_blocks_detected": len(code_spans),
        "code_blocks_preserved": code_preserved,
        "json_blocks_detected": len(json_spans),
        "json_blocks_preserved": json_preserved,
        "system_prompt_spans_detected": len(system_spans),
        "system_prompts_preserved": system_preserved,
        "tool_schemas_checked": len(tool_schema_groups),
        "tool_schema_occurrences": sum(
            _schema_occurrences(text, group)
            for group in tool_schema_groups
        ),
        "tool_schema_spans_detected": len(tool_schema_spans),
        "tool_schemas_missing": len(missing_tool_schemas),
        "tool_schemas_missing_preview": [
            next((candidate[:120] for candidate in group if candidate), "")
            for group in missing_tool_schemas
        ],
        "tool_schemas_preserved": tool_schemas_preserved,
        "decision": decision,
        "compressor_latency_ms": round((time.perf_counter() - started) * 1000.0, 3),
        "deletion_only": _is_subsequence(output, text),
        "deterministic": True,
        "rules_version": RULES_VERSION,
        "classifier": {
            "model": CLASSIFIER_MODEL,
            "status": classifier_status,
            "artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR),
            "artifact_dir": CLASSIFIER_ARTIFACT_DIR,
            "error": classifier_error,
            "labels_received": len(classifier_labels),
            "drop_labels": classifier_drop_labels,
            "drop_spans_applied": classifier_applied,
            "drop_spans_blocked_by_safety": classifier_blocked,
            "drop_spans_blocked_by_reason": classifier_blocked_by_reason,
        },
        "removed_spans_count": len(drop_ranges),
        "removed_chars": sum(end - start for start, end in drop_ranges),
        "rule_drop_spans_applied": len(drops),
        "rule_drop_chars": sum(end - start for start, end, _, _ in drops),
        "classifier_drop_spans_applied": len(classifier_ranges),
        "classifier_drop_chars": sum(end - start for start, end in classifier_ranges),
        "dropped_segments_count": len(drops),
        "dropped_segments": dropped_segments,
        "token_count_exact": token_count_exact,
        "token_count_method": token_method,
        "tokenizer_model": tokenizer_model,
    }
    receipt["input_sha256"] = _sha256_text(text)
    receipt["output_sha256"] = _sha256_text(output)
    receipt["removed_sha256"] = _sha256_text(
        "".join(text[start:end] for start, end in drop_ranges)
    )
    receipt["receipt_id"] = _receipt_id({
        "input_sha256": receipt["input_sha256"],
        "output_sha256": receipt["output_sha256"],
        "removed_sha256": receipt["removed_sha256"],
        "tokens_saved": saved,
        "compression_percentage": compression_percentage,
        "decision": decision,
        "rules_version": RULES_VERSION,
        "classifier": receipt["classifier"],
        "dropped_segments": dropped_segments,
    })
    return {
        "schema_version": API_SCHEMA_VERSION,
        "status": "ok",
        "endpoint": "/v1/compress",
        "maturity": "measurement_only",
        "output": output,
        "original_input_tokens": before,
        "output_tokens": after,
        "tokens_saved": saved,
        "compression_percentage": compression_percentage,
        "request_id": request_id,
        "idempotency_key": idempotency_key,
        "receipt": receipt,
    }
|
|
|
|
| def _merge_batch_item_payload( |
| payload: dict[str, Any], |
| item: Any, |
| index: int, |
| ) -> tuple[str | None, dict[str, Any]]: |
| if isinstance(item, str): |
| return None, { |
| "input": item, |
| "compression_settings": payload.get("compression_settings"), |
| "protected_spans": payload.get("protected_spans"), |
| "tool_schemas": payload.get("tool_schemas", payload.get("tools")), |
| "tokenizer_model": payload.get("tokenizer_model"), |
| "request_id": payload.get("request_id"), |
| "idempotency_key": payload.get("idempotency_key"), |
| } |
| if not isinstance(item, dict): |
| raise ValueError(f"inputs[{index}] must be a string or an object") |
|
|
| item_id = item.get("id") |
| if item_id is not None and not isinstance(item_id, str): |
| raise ValueError(f"inputs[{index}].id must be a string") |
| if "input" not in item: |
| raise ValueError(f"inputs[{index}].input is required") |
|
|
| top_settings = payload.get("compression_settings") |
| item_settings = item.get("compression_settings") |
| if top_settings is not None and not isinstance(top_settings, dict): |
| raise ValueError("compression_settings must be an object") |
| if item_settings is not None and not isinstance(item_settings, dict): |
| raise ValueError(f"inputs[{index}].compression_settings must be an object") |
| settings = { |
| **(top_settings or {}), |
| **(item_settings or {}), |
| } or None |
|
|
| return item_id, { |
| "input": item.get("input"), |
| "compression_settings": settings, |
| "protected_spans": item.get("protected_spans", payload.get("protected_spans")), |
| "tool_schemas": item.get( |
| "tool_schemas", |
| item.get("tools", payload.get("tool_schemas", payload.get("tools"))), |
| ), |
| "tokenizer_model": item.get("tokenizer_model", payload.get("tokenizer_model")), |
| "request_id": item.get("request_id", payload.get("request_id")), |
| "idempotency_key": item.get( |
| "idempotency_key", |
| payload.get("idempotency_key"), |
| ), |
| } |
|
|
|
|
def _handle_batch(payload: dict[str, Any]) -> dict[str, Any]:
    """Compress every entry of ``payload["inputs"]`` and aggregate the results.

    Each entry is resolved against the batch-level options by
    ``_merge_batch_item_payload`` and compressed independently with
    ``_compress_text``. An ``HTTPException`` or ``ValueError`` raised for one
    item is recorded in that item's result instead of failing the request.
    The top-level ``status`` is ``"partial_error"`` when any item failed.
    Token totals and ``compression_percentage`` cover successful items only.

    Raises:
        HTTPException: 400 when the batch envelope itself is invalid. That is
            ``input`` or ``messages`` supplied alongside ``inputs``, ``inputs``
            that is not a non-empty list, or a non-string
            ``request_id``/``idempotency_key``.
    """
    if "input" in payload:
        raise HTTPException(status_code=400, detail="provide either input or inputs, not both")
    if "messages" in payload:
        # Batch payloads are routed here before the messages handler is
        # consulted, so without this check the messages would be silently
        # ignored.
        raise HTTPException(status_code=400, detail="provide either messages, input, or inputs")
    inputs = payload.get("inputs")
    if not isinstance(inputs, list):
        raise HTTPException(status_code=400, detail="inputs must be a list")
    if not inputs:
        raise HTTPException(status_code=400, detail="inputs list is empty")
    request_id = payload.get("request_id")
    if request_id is not None and not isinstance(request_id, str):
        raise HTTPException(status_code=400, detail="request_id must be a string")
    idempotency_key = payload.get("idempotency_key")
    if idempotency_key is not None and not isinstance(idempotency_key, str):
        raise HTTPException(status_code=400, detail="idempotency_key must be a string")

    results: list[dict[str, Any]] = []
    receipt_ids: list[str] = []
    totals = {
        "original_input_tokens": 0,
        "output_tokens": 0,
        "tokens_saved": 0,
    }
    succeeded = 0
    failed = 0
    for index, item in enumerate(inputs):
        item_result: dict[str, Any] = {"index": index}
        # Record a well-formed id up front so that it is still reported when
        # the item later fails validation or compression.
        if isinstance(item, dict) and isinstance(item.get("id"), str):
            item_result["id"] = item["id"]
        try:
            _, item_payload = _merge_batch_item_payload(payload, item, index)
            result = _compress_text(item_payload)
        except (HTTPException, ValueError) as exc:
            # Report the HTTPException's detail text, or the ValueError message.
            detail = exc.detail if isinstance(exc, HTTPException) else exc
            item_result.update({"status": "error", "error": str(detail)})
            failed += 1
        else:
            item_result.update({"status": "ok", **result})
            for key in totals:
                totals[key] += int(result[key])
            receipt_id = result.get("receipt", {}).get("receipt_id")
            if receipt_id:
                receipt_ids.append(receipt_id)
            succeeded += 1
        results.append(item_result)

    compression_pct = (
        round(100.0 * totals["tokens_saved"] / totals["original_input_tokens"], 1)
        if totals["original_input_tokens"]
        else 0.0
    )
    body = {
        "schema_version": API_SCHEMA_VERSION,
        "status": "ok" if failed == 0 else "partial_error",
        "endpoint": "/v1/compress",
        "maturity": "measurement_only",
        "input_count": len(inputs),
        "succeeded": succeeded,
        "failed": failed,
        **totals,
        "compression_percentage": compression_pct,
        "receipt_ids": receipt_ids,
        "results": results,
    }
    if request_id is not None:
        body["request_id"] = request_id
    if idempotency_key is not None:
        body["idempotency_key"] = idempotency_key
    return body
|
|
|
|
def _handle_messages(payload: dict[str, Any]) -> dict[str, Any]:
    """Compress the ``content`` of each chat message in ``payload["messages"]``.

    A message's entire content is passed to ``_compress_text`` as an extra
    protected span in two cases:
    - its role is in ``compression_settings.protected_roles`` (default
      ``DEFAULT_PROTECTED_MESSAGE_ROLES``);
    - ``compression_settings.compress_roles`` is given and does not list the
      role.

    Messages with non-string content are copied through unchanged and
    reported as ``skipped``. Each compressed message gets its own receipt, and
    an aggregate receipt hashes the whole input and output conversation.

    Raises:
        HTTPException: 400 for an invalid envelope: ``input``/``inputs``
            alongside ``messages``, empty or non-object messages, a
            non-object ``compression_settings``, a ``protected_spans`` that is
            not a list of strings, or a non-string
            ``request_id``/``idempotency_key``.
    """
    if "input" in payload or "inputs" in payload:
        raise HTTPException(status_code=400, detail="provide either messages, input, or inputs")
    messages = payload.get("messages")
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")
    if not all(isinstance(message, dict) for message in messages):
        raise HTTPException(status_code=400, detail="messages entries must be objects")
    # Only an absent/null value means "no settings". Other falsy values such
    # as [] or "" are rejected rather than silently treated as {}.
    settings = payload.get("compression_settings")
    if settings is None:
        settings = {}
    elif not isinstance(settings, dict):
        raise HTTPException(status_code=400, detail="compression_settings must be an object")
    protected_roles = _string_set(
        settings.get("protected_roles"),
        default=DEFAULT_PROTECTED_MESSAGE_ROLES,
        field_name="compression_settings.protected_roles",
    )
    compress_roles = (
        _string_set(
            settings.get("compress_roles"),
            field_name="compression_settings.compress_roles",
        )
        if "compress_roles" in settings else None
    )
    # Same rule as compression_settings: only absent/null defaults to [].
    protected_values = payload.get("protected_spans")
    if protected_values is None:
        protected_values = []
    elif not isinstance(protected_values, list) or not all(
        isinstance(value, str) for value in protected_values
    ):
        raise HTTPException(status_code=400, detail="protected_spans must be strings")
    request_id = payload.get("request_id")
    if request_id is not None and not isinstance(request_id, str):
        raise HTTPException(status_code=400, detail="request_id must be a string")
    idempotency_key = payload.get("idempotency_key")
    if idempotency_key is not None and not isinstance(idempotency_key, str):
        raise HTTPException(status_code=400, detail="idempotency_key must be a string")
    tokenizer_model = _optional_string(payload, "tokenizer_model")
    # Every message uses the same schemas, so resolve them once.
    tool_schemas = payload.get("tool_schemas", payload.get("tools"))

    output_messages: list[dict[str, Any]] = []
    receipts: list[dict[str, Any]] = []
    nested_receipts: list[dict[str, Any]] = []
    receipt_ids: list[str] = []
    original_tokens = 0
    output_tokens = 0
    compressed_message_count = 0
    skipped_message_count = 0

    for index, message in enumerate(messages):
        role = _message_role(message, index)
        content = message.get("content")
        # Shallow copy: the caller's message dicts are never mutated.
        output_message = dict(message)
        if not isinstance(content, str):
            # Non-string content (e.g. a list or null) passes through untouched.
            skipped_message_count += 1
            receipts.append({
                "index": index,
                "role": role,
                "status": "skipped",
                "reason": "non_string_content",
                "content_type": type(content).__name__,
                "original_input_tokens": 0,
                "output_tokens": 0,
                "tokens_saved": 0,
            })
            output_messages.append(output_message)
            continue

        role_protected = role in protected_roles or (
            compress_roles is not None and role not in compress_roles
        )
        item_protected = list(protected_values)
        if role_protected and content:
            # Mark the entire message content as a protected span.
            item_protected.append(content)
        result = _compress_text({
            "input": content,
            "compression_settings": settings,
            "protected_spans": item_protected,
            "tool_schemas": tool_schemas,
            "tokenizer_model": tokenizer_model,
            "request_id": request_id,
            "idempotency_key": idempotency_key,
        })
        output_message["content"] = result["output"]
        output_messages.append(output_message)
        original_tokens += int(result["original_input_tokens"])
        output_tokens += int(result["output_tokens"])
        tokens_saved = int(result["tokens_saved"])
        if tokens_saved > 0:
            compressed_message_count += 1
        receipt = result["receipt"]
        nested_receipts.append(receipt)
        receipt_ids.append(receipt["receipt_id"])
        receipts.append({
            "index": index,
            "role": role,
            "status": "ok",
            "protected_by_role": role_protected,
            "original_input_tokens": result["original_input_tokens"],
            "output_tokens": result["output_tokens"],
            "tokens_saved": tokens_saved,
            "compression_percentage": result["compression_percentage"],
            "receipt_id": receipt["receipt_id"],
            "receipt": receipt,
        })

    tokens_saved_total = max(0, original_tokens - output_tokens)
    compression_pct = (
        round(100.0 * tokens_saved_total / original_tokens, 1)
        if original_tokens else 0.0
    )
    decision = _message_decision(
        tokens_saved=tokens_saved_total,
        receipts=nested_receipts,
    )
    # Both hashes feed the receipt id and are echoed in the receipt; compute once.
    input_sha256 = _stable_json_sha256(messages)
    output_sha256 = _stable_json_sha256(output_messages)
    aggregate_receipt = {
        "receipt_version": "message-compression-receipt-v0.1.0",
        "receipt_id": _aggregate_receipt_id({
            "input_sha256": input_sha256,
            "output_sha256": output_sha256,
            "receipt_ids": receipt_ids,
            "tokens_saved": tokens_saved_total,
            "compression_percentage": compression_pct,
            "decision": decision,
        }),
        "request_id": request_id,
        "idempotency_key": idempotency_key,
        "message_count": len(messages),
        "compressed_message_count": compressed_message_count,
        "skipped_message_count": skipped_message_count,
        "protected_roles": sorted(protected_roles),
        "compress_roles": sorted(compress_roles) if compress_roles is not None else None,
        "decision": decision,
        "deletion_only": all(
            nested.get("deletion_only", True) for nested in nested_receipts
        ),
        "deterministic": True,
        "input_sha256": input_sha256,
        "output_sha256": output_sha256,
        "message_receipt_ids": receipt_ids,
    }
    compressed_prompt: dict[str, Any] = {
        "messages": output_messages,
        "protected_spans": protected_values,
    }
    if "tools" in payload:
        compressed_prompt["tools"] = payload["tools"]
    if "tool_schemas" in payload:
        compressed_prompt["tool_schemas"] = payload["tool_schemas"]
    body = {
        "schema_version": API_SCHEMA_VERSION,
        "status": "ok",
        "endpoint": "/v1/compress",
        "maturity": "measurement_only",
        "messages": output_messages,
        "compressed_messages": output_messages,
        "compressed_prompts": [compressed_prompt],
        "message_count": len(messages),
        "compressed_message_count": compressed_message_count,
        "skipped_message_count": skipped_message_count,
        "original_input_tokens": original_tokens,
        "output_tokens": output_tokens,
        "tokens_saved": tokens_saved_total,
        "compression_percentage": compression_pct,
        "receipt_ids": receipt_ids,
        "receipts": receipts,
        "receipt": aggregate_receipt,
    }
    if request_id is not None:
        body["request_id"] = request_id
    if idempotency_key is not None:
        body["idempotency_key"] = idempotency_key
    return body
|
|
|
|
def _dispatch_compress(payload: dict[str, Any]) -> dict[str, Any]:
    """Route a decoded ``/v1/compress`` body to the matching handler.

    The routing precedence is:
    1. ``inputs`` (batch);
    2. ``messages`` (chat);
    3. anything else, which is handled as a single ``input`` request.
    """
    handler = (
        _handle_batch if "inputs" in payload
        else _handle_messages if "messages" in payload
        else _compress_text
    )
    return handler(payload)
|
|
|
|
| def _tokens(text: str) -> list[dict[str, Any]]: |
| started = time.perf_counter() |
| try: |
| tokenizer = _get_tokenizer() |
| encoded = tokenizer( |
| text, |
| add_special_tokens=False, |
| return_offsets_mapping=True, |
| ) |
| pieces = tokenizer.convert_ids_to_tokens(encoded["input_ids"]) |
| offsets = encoded["offset_mapping"] |
| token_source = "deberta_tokenizer" |
| except Exception: |
| matches = list(re.finditer(r"\S+|\s+", text)) |
| pieces = [match.group(0) for match in matches] |
| offsets = [(match.start(), match.end()) for match in matches] |
| token_source = "regex_fallback" |
|
|
| elapsed = round((time.perf_counter() - started) * 1000.0, 3) |
| return [ |
| { |
| "token": piece, |
| "label": "KEEP", |
| "score": 1.0, |
| "start": int(offsets[i][0]), |
| "end": int(offsets[i][1]), |
| "source": token_source, |
| "classifier_latency_ms": elapsed if i == 0 else 0.0, |
| } |
| for i, piece in enumerate(pieces) |
| ] |
|
|
|
|
@app.get("/health")
def health() -> dict[str, Any]:
    """Liveness probe that also reports classifier and transport configuration.

    The classifier status and error come from calling
    ``_manifest_drop_labels`` on an empty string. The labels it returns are
    discarded.
    """
    status, _unused_labels, error = _manifest_drop_labels("")
    transport_settings = {
        "request_id_header": REQUEST_ID_HEADER,
        "idempotency_key_header": IDEMPOTENCY_KEY_HEADER,
        "content_encoding_fallback_header": TOUCHDOWN_CONTENT_ENCODING_HEADER,
        "gzip_request_bodies": True,
        "gzip_responses": True,
        "idempotency_mode": "in_memory_response_replay",
        "idempotency_ttl_seconds": _idempotency_ttl_seconds(),
    }
    return {
        "status": "ok",
        "classifier_model": CLASSIFIER_MODEL,
        "classifier_status": status,
        "classifier_artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR),
        "classifier_error": error,
        "phase": "rules_api_first",
        "transport": transport_settings,
    }
|
|
|
|
@app.post("/v1/classify")
def classify(payload: dict[str, Any]) -> dict[str, Any]:
    """Return per-token labels for ``payload["input"]``.

    The response uses the labels from ``_manifest_drop_labels`` when that
    call returns a non-empty list. Otherwise it uses the all-KEEP labels from
    ``_tokens``.

    The reported ``status`` is the manifest status whenever any of these holds:
    - an artifact directory is configured;
    - artifact labels came back;
    - an error came back.
    Otherwise it is ``"tokenizer_only_until_trained_head"``.

    Raises:
        HTTPException: 400 when ``input`` is missing or not a string.
    """
    text = payload.get("input")
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="input must be a string")
    manifest_status, artifact_labels, manifest_error = _manifest_drop_labels(text)
    labels = artifact_labels if artifact_labels else _tokens(text)
    artifact_involved = CLASSIFIER_ARTIFACT_DIR or artifact_labels or manifest_error
    if artifact_involved:
        reported_status = manifest_status
    else:
        reported_status = "tokenizer_only_until_trained_head"
    scaffold_note = (
        "This free Space scaffold loads the DeBERTa tokenizer and returns "
        "KEEP-only labels until a trained KEEP/DROP head or ONNX export is "
        "mounted. Do not treat it as a production compressor."
    )
    return {
        "model": CLASSIFIER_MODEL,
        "task": "token_classification_keep_drop",
        "status": reported_status,
        "artifact_dir_configured": bool(CLASSIFIER_ARTIFACT_DIR),
        "error": manifest_error,
        "labels": labels,
        "note": scaffold_note,
    }
|
|
|
|
@app.post("/v1/compress")
async def compress(request: Request) -> dict[str, Any]:
    """Decode, parse and dispatch a ``/v1/compress`` request.

    The body may be content-encoded. The encoding is taken from the standard
    Content-Encoding header, falling back to the Touchdown-specific one. An
    empty body is treated as ``{}``. The parsed object is enriched with
    correlation ids and handed to ``_handle_compress_with_idempotency``.

    Raises:
        HTTPException: 415 for an unsupported content encoding. 400 for
            undecodable bodies, invalid JSON, or a non-object JSON value.
    """
    raw = await request.body()
    headers = request.headers
    declared_encoding = (
        headers.get(CONTENT_ENCODING_HEADER)
        or headers.get(TOUCHDOWN_CONTENT_ENCODING_HEADER)
    )
    try:
        decoded, _was_gzip = _decode_request_body(raw, declared_encoding)
    except UnsupportedContentEncoding as exc:
        raise HTTPException(status_code=415, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    payload: Any = {}
    if decoded:
        try:
            payload = json.loads(decoded.decode("utf-8"))
        except Exception as exc:
            raise HTTPException(status_code=400, detail="invalid JSON body") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")

    enriched = _correlation_payload(
        payload,
        request_id=getattr(request.state, "request_id", None),
        idempotency_key=headers.get(IDEMPOTENCY_KEY_HEADER),
    )
    return _handle_compress_with_idempotency(enriched)
|
|